Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 7, 2024
1 parent 3bc51bd commit 60d7619
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 22 deletions.
15 changes: 0 additions & 15 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,18 +831,3 @@ def chunked_nanfirst(darray, axis):

def chunked_nanlast(darray, axis):
    """Reduce ``darray`` along ``axis`` with ``nputils.nanlast``.

    The work is delegated to ``_chunked_first_or_last``, which receives
    ``nputils.nanlast`` as its ``op`` argument.
    """
    reducer = nputils.nanlast
    return _chunked_first_or_last(darray, axis, op=reducer)


def shuffle_array(array, indices: list[list[int]], axis: int):
    """Reorder ``array`` along ``axis`` using the given index lists.

    Parameters
    ----------
    array
        Array to reorder. Duck dask arrays are shuffled through their own
        ``shuffle`` method; everything else goes through ``np.take``.
    indices
        Lists of integer positions. For non-dask input they are concatenated
        in order into one flat indexer.
    axis
        Axis along which the positions are applied.

    Raises
    ------
    ValueError
        If ``array`` is a duck dask array and the installed dask is older
        than 2024.08.0.
    """
    # TODO: do chunk manager dance here.
    if not is_duck_dask_array(array):
        # TODO: Do the array API thing here.
        flat_indexer = np.concatenate(indices)
        return np.take(array, indices=flat_indexer, axis=axis)

    if module_available("dask", minversion="2024.08.0"):
        # TODO: handle dimensions
        return array.shuffle(indexer=indices, axis=axis)

    raise ValueError(
        "This method is very inefficient on dask<2024.08.0. Please upgrade."
    )
12 changes: 7 additions & 5 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,6 @@ def shuffle(self) -> None:
"""
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.duck_array_ops import shuffle_array

(grouper,) = self.groupers
dim = self._group_dim
Expand All @@ -538,6 +537,8 @@ def shuffle(self) -> None:
if all(isinstance(idx, slice) for idx in self._group_indices):
return

indices: tuple[list[int]] = self._group_indices # type: ignore[assignment]

was_array = isinstance(self._obj, DataArray)
as_dataset = self._obj._to_temp_dataset() if was_array else self._obj

Expand All @@ -546,21 +547,22 @@ def shuffle(self) -> None:
if dim not in var.dims:
shuffled[name] = var
continue
shuffled_data = shuffle_array(
var._data, list(self._group_indices), axis=var.get_axis_num(dim)
)
shuffled[name] = var._replace(data=shuffled_data)
shuffled[name] = var._shuffle(indices=list(indices), dim=dim)

# Replace self._group_indices with slices
slices = []
start = 0
for idxr in self._group_indices:
if TYPE_CHECKING:
assert not isinstance(idxr, slice)
slices.append(slice(start, start + len(idxr)))
start += len(idxr)
# TODO: we have now broken the invariant
# self._group_indices ≠ self.groupers[0].group_indices
self._group_indices = tuple(slices)
if was_array:
if TYPE_CHECKING:
assert isinstance(self._obj, DataArray)
self._obj = self._obj._from_temp_dataset(shuffled)
else:
self._obj = shuffled
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def copy(
ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]

GroupKey = Any
GroupIndex = Union[int, slice, list[int]]
GroupIndex = Union[slice, list[int]]
GroupIndices = tuple[GroupIndex, ...]
Bins = Union[
int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index
Expand Down
18 changes: 17 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@
maybe_coerce_to_str,
)
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import (
integer_types,
is_0d_dask_array,
is_chunked_array,
to_duck_array,
)
from xarray.util.deprecation_helpers import deprecate_dims

NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
Expand Down Expand Up @@ -998,6 +1004,16 @@ def compute(self, **kwargs):
new = self.copy(deep=False)
return new.load(**kwargs)

def _shuffle(self, indices: list[list[int]], dim: Hashable) -> Self:
    """Return a new variable with the data reordered along ``dim``.

    Parameters
    ----------
    indices
        Lists of integer positions along ``dim``. Chunked arrays receive the
        nested lists as-is; other arrays are indexed with their
        concatenation.
    dim
        Name of the dimension to reorder.

    Returns
    -------
    Self
        A variable holding the shuffled data, in both branches.
    """
    array = self._data
    if is_chunked_array(array):
        chunkmanager = get_chunked_array_type(array)
        shuffled = chunkmanager.shuffle(
            array, indexer=indices, axis=self.get_axis_num(dim)
        )
        # The chunk manager hands back a bare array; wrap it so callers get
        # a variable, matching the ``isel`` branch and the ``-> Self``
        # annotation.
        return self._replace(data=shuffled)
    return self.isel({dim: np.concatenate(indices)})

def isel(
self,
indexers: Mapping[Any, Any] | None = None,
Expand Down
9 changes: 9 additions & 0 deletions xarray/namedarray/daskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,12 @@ def store(
targets=targets,
**kwargs,
)

def shuffle(self, x: DaskArray, indexer: list[list[int]], axis: int) -> DaskArray:
    """Shuffle ``x`` along ``axis`` by delegating to ``dask.array.shuffle``.

    Parameters
    ----------
    x
        Dask array to shuffle.
    indexer
        Lists of integer positions, passed to dask unchanged.
    axis
        Axis along which to shuffle.

    Raises
    ------
    ValueError
        If the installed dask is older than 2024.08.0.
    """
    import dask.array

    if module_available("dask", minversion="2024.08.0"):
        return dask.array.shuffle(x, indexer, axis)

    raise ValueError(
        "This method is very inefficient on dask<2024.08.0. Please upgrade."
    )
5 changes: 5 additions & 0 deletions xarray/namedarray/parallelcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,11 @@ def compute(
"""
raise NotImplementedError()

def shuffle(
    self,
    x: T_ChunkedArray,
    indexer: list[list[int]],
    axis: int,
) -> T_ChunkedArray:
    """Shuffle ``x`` along ``axis`` according to ``indexer``.

    This implementation is a stub: it performs no work and always raises
    ``NotImplementedError``.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError

@property
def array_api(self) -> Any:
"""
Expand Down

0 comments on commit 60d7619

Please sign in to comment.