diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py
index 6e5a471..0b2e2c6 100644
--- a/pylops_mpi/DistributedArray.py
+++ b/pylops_mpi/DistributedArray.py
@@ -5,6 +5,7 @@
 from enum import Enum
 
 from pylops.utils import DTypeLike, NDArray
+from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_module, get_array_module, get_module_name
 
@@ -78,7 +79,7 @@ class DistributedArray:
     axis : :obj:`int`, optional
         Axis along which distribution occurs. Defaults to ``0``.
     local_shapes : :obj:`list`, optional
-        List of tuples representing local shapes at each rank.
+        List of tuples or integers representing local shapes at each rank.
     engine : :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
     dtype : :obj:`str`, optional
@@ -88,7 +89,7 @@
     def __init__(self, global_shape: Union[Tuple, Integral],
                  base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
-                 local_shapes: Optional[List[Tuple]] = None,
+                 local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
                  engine: Optional[str] = "numpy",
                  dtype: Optional[DTypeLike] = np.float64):
         if isinstance(global_shape, Integral):
@@ -100,10 +101,12 @@
             raise ValueError(f"Should be either {Partition.BROADCAST} "
                              f"or {Partition.SCATTER}")
         self.dtype = dtype
-        self._global_shape = global_shape
+        self._global_shape = _value_or_sized_to_tuple(global_shape)
         self._base_comm = base_comm
         self._partition = partition
         self._axis = axis
+
+        local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
         self._check_local_shapes(local_shapes)
         self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm,
                                                                                           partition, axis)
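
A minimal usage sketch of the relaxed local_shapes signature (not part of the diff): it assumes a run with two MPI ranks, that DistributedArray is importable from the pylops_mpi package top level, and that local_shape is the public accessor for _local_shape; the shapes used are illustrative only.

    # run with, e.g.: mpiexec -n 2 python example.py  (hypothetical script name)
    import numpy as np
    from mpi4py import MPI
    from pylops_mpi import DistributedArray

    comm = MPI.COMM_WORLD

    # global_shape passed as a plain integer and local_shapes mixing an integer
    # with a tuple; both forms are normalized to tuples via _value_or_sized_to_tuple
    arr = DistributedArray(global_shape=10,
                           base_comm=comm,
                           local_shapes=[6, (4,)],
                           dtype=np.float64)
    print(comm.rank, arr.local_shape)  # rank 0 -> (6,), rank 1 -> (4,)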