diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index bdd858a5c50..b84ff466fa6 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -123,11 +123,11 @@ jobs: python xarray/util/print_versions.py - name: Install mypy run: | - python -m pip install "mypy<1.9" --force-reinstall + python -m pip install "mypy" --force-reinstall - name: Run mypy run: | - python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report xarray/ + python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report - name: Upload mypy coverage to Codecov uses: codecov/codecov-action@v4.5.0 @@ -138,7 +138,7 @@ jobs: name: codecov-umbrella fail_ci_if_error: false - mypy39: + mypy-min: name: Mypy 3.10 runs-on: "ubuntu-latest" needs: detect-ci-trigger @@ -177,32 +177,30 @@ jobs: python xarray/util/print_versions.py - name: Install mypy run: | - python -m pip install "mypy<1.9" --force-reinstall + python -m pip install "mypy" --force-reinstall - name: Run mypy run: | - python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report xarray/ + python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report - name: Upload mypy coverage to Codecov uses: codecov/codecov-action@v4.5.0 with: file: mypy_report/cobertura.xml - flags: mypy39 + flags: mypy-min env_vars: PYTHON_VERSION name: codecov-umbrella fail_ci_if_error: false - - pyright: name: Pyright runs-on: "ubuntu-latest" needs: detect-ci-trigger if: | - always() - && ( - contains( github.event.pull_request.labels.*.name, 'run-pyright') - ) + always() + && ( + contains( github.event.pull_request.labels.*.name, 'run-pyright') + ) defaults: run: shell: bash -l {0} @@ -258,10 +256,10 @@ jobs: runs-on: "ubuntu-latest" needs: detect-ci-trigger if: | - always() - && ( - contains( github.event.pull_request.labels.*.name, 'run-pyright') - ) + always() + && ( + contains( github.event.pull_request.labels.*.name, 'run-pyright') + ) defaults: run: shell: bash -l {0} @@ -312,8 +310,6 @@ jobs: name: codecov-umbrella fail_ci_if_error: false - - min-version-policy: name: Minimum Version Policy runs-on: "ubuntu-latest" diff --git a/.github/workflows/pypi-release.yaml b/.github/workflows/pypi-release.yaml index d50fff220a5..2eeabf469b7 100644 --- a/.github/workflows/pypi-release.yaml +++ b/.github/workflows/pypi-release.yaml @@ -88,7 +88,7 @@ jobs: path: dist - name: Publish package to TestPyPI if: github.event_name == 'push' - uses: pypa/gh-action-pypi-publish@v1.9.0 + uses: pypa/gh-action-pypi-publish@v1.10.1 with: repository_url: https://test.pypi.org/legacy/ verbose: true @@ -111,6 +111,6 @@ jobs: name: releases path: dist - name: Publish package to PyPI - uses: pypa/gh-action-pypi-publish@v1.9.0 + uses: pypa/gh-action-pypi-publish@v1.10.1 with: verbose: true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 57f8b9e86f5..6ebd66bdf69 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,7 @@ # https://pre-commit.com/ ci: autoupdate_schedule: monthly + autoupdate_commit_msg: 'Update pre-commit hooks' exclude: 'xarray/datatree_.*' repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -13,7 +14,7 @@ repos: - id: mixed-line-ending - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. 
- rev: 'v0.6.2' + rev: 'v0.6.3' hooks: - id: ruff args: ["--fix", "--show-fixes"] diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index 0956be67dad..2661ec5cfba 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -724,7 +724,7 @@ class PerformanceBackend(xr.backends.BackendEntrypoint): def open_dataset( self, filename_or_obj: str | os.PathLike | None, - drop_variables: tuple[str] = None, + drop_variables: tuple[str, ...] = None, *, mask_and_scale=True, decode_times=True, diff --git a/asv_bench/benchmarks/datatree.py b/asv_bench/benchmarks/datatree.py new file mode 100644 index 00000000000..13eedd0a518 --- /dev/null +++ b/asv_bench/benchmarks/datatree.py @@ -0,0 +1,15 @@ +import xarray as xr +from xarray.core.datatree import DataTree + + +class Datatree: + def setup(self): + run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})}) + self.d_few = {"run1": run1} + self.d_many = {f"run{i}": run1.copy() for i in range(100)} + + def time_from_dict_few(self): + DataTree.from_dict(self.d_few) + + def time_from_dict_many(self): + DataTree.from_dict(self.d_many) diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index 065c1b3b17f..a5261d5106a 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -1,4 +1,5 @@ # import flox to avoid the cost of first import +import cftime import flox.xarray # noqa import numpy as np import pandas as pd @@ -96,7 +97,7 @@ def setup(self, *args, **kwargs): requires_dask() super().setup(**kwargs) - self.ds1d = self.ds1d.chunk({"dim_0": 50}).to_dataframe() + self.ds1d = self.ds1d.chunk({"dim_0": 50}).to_dask_dataframe() self.ds1d_mean = self.ds1d.groupby("b").mean().compute() def time_binary_op_2d(self): @@ -169,7 +170,21 @@ class GroupByLongTime: def setup(self, use_cftime, use_flox): arr = np.random.randn(10, 10, 365 * 30) time = xr.date_range("2000", periods=30 * 365, use_cftime=use_cftime) - self.da = xr.DataArray(arr, dims=("y", "x", "time"), coords={"time": time}) + + # GH9426 - deep-copying CFTime object arrays is weirdly slow + asda = xr.DataArray(time) + labeled_time = [] + for year, month in zip(asda.dt.year, asda.dt.month, strict=True): + labeled_time.append(cftime.datetime(year, month, 1)) + + self.da = xr.DataArray( + arr, + dims=("y", "x", "time"), + coords={"time": time, "time2": ("time", labeled_time)}, + ) + + def time_setup(self, use_cftime, use_flox): + self.da.groupby("time.month") def time_mean(self, use_cftime, use_flox): with xr.set_options(use_flox=use_flox): diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py index 579f4f00fbc..a19d17ff09a 100644 --- a/asv_bench/benchmarks/rolling.py +++ b/asv_bench/benchmarks/rolling.py @@ -64,7 +64,7 @@ def time_rolling_long(self, func, pandas, use_bottleneck): def time_rolling_np(self, window_, min_periods, use_bottleneck): with xr.set_options(use_bottleneck=use_bottleneck): self.ds.rolling(x=window_, center=False, min_periods=min_periods).reduce( - getattr(np, "nansum") + np.nansum ).load() @parameterized( diff --git a/ci/requirements/bare-minimum.yml b/ci/requirements/bare-minimum.yml index 0878222da35..d9590d95165 100644 --- a/ci/requirements/bare-minimum.yml +++ b/ci/requirements/bare-minimum.yml @@ -11,6 +11,6 @@ dependencies: - pytest-env - pytest-xdist - pytest-timeout - - numpy=1.23 + - numpy=1.24 - packaging=23.1 - - pandas=2.0 + - pandas=2.1 diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index 
ea1dc7b7fb0..b5a9176a62b 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -9,37 +9,37 @@ dependencies: # doc/user-guide/installing.rst, doc/user-guide/plotting.rst and setup.py. - python=3.10 - array-api-strict=1.0 # dependency for testing the array api compat - - boto3=1.26 + - boto3=1.28 - bottleneck=1.3 - - cartopy=0.21 + - cartopy=0.22 - cftime=1.6 - coveralls - - dask-core=2023.4 - - distributed=2023.4 + - dask-core=2023.9 + - distributed=2023.9 # Flox > 0.8 has a bug with numbagg versions # It will require numbagg > 0.6 # so we should just skip that series eventually # or keep flox pinned for longer than necessary - flox=0.7 - - h5netcdf=1.1 + - h5netcdf=1.2 # h5py and hdf5 tend to cause conflicts # for e.g. hdf5 1.12 conflicts with h5py=3.1 # prioritize bumping other packages instead - h5py=3.8 - hdf5=1.12 - hypothesis - - iris=3.4 + - iris=3.7 - lxml=4.9 # Optional dep of pydap - matplotlib-base=3.7 - nc-time-axis=1.4 # netcdf follows a 1.major.minor[.patch] convention # (see https://github.com/Unidata/netcdf4-python/issues/1090) - netcdf4=1.6.0 - - numba=0.56 + - numba=0.57 - numbagg=0.2.1 - - numpy=1.23 + - numpy=1.24 - packaging=23.1 - - pandas=2.0 + - pandas=2.1 - pint=0.22 - pip - pydap=3.4 @@ -49,9 +49,9 @@ dependencies: - pytest-xdist - pytest-timeout - rasterio=1.3 - - scipy=1.10 + - scipy=1.11 - seaborn=0.12 - sparse=0.14 - toolz=0.12 - - typing_extensions=4.5 - - zarr=2.14 + - typing_extensions=4.7 + - zarr=2.16 diff --git a/design_notes/flexible_indexes_notes.md b/design_notes/flexible_indexes_notes.md index b36ce3e46ed..f4a2c1c2125 100644 --- a/design_notes/flexible_indexes_notes.md +++ b/design_notes/flexible_indexes_notes.md @@ -71,7 +71,7 @@ An `XarrayIndex` subclass must/should/may implement the following properties/met - a `data` property to access index's data and map it to coordinate data (see [Section 4](#4-indexvariable)) - a `__getitem__()` implementation to propagate the index through DataArray/Dataset indexing operations - `equals()`, `union()` and `intersection()` methods for data alignment (see [Section 2.6](#26-using-indexes-for-data-alignment)) -- Xarray coordinate getters (see [Section 2.2.4](#224-implicit-coodinates)) +- Xarray coordinate getters (see [Section 2.2.4](#224-implicit-coordinates)) - a method that may return a new index and that will be called when one of the corresponding coordinates is dropped from the Dataset/DataArray (multi-coordinate indexes) - `encode()`/`decode()` methods that would allow storage-agnostic serialization and fast-path reconstruction of the underlying index object(s) (see [Section 2.8](#28-index-encoding)) - one or more "non-standard" methods or properties that could be leveraged in Xarray 3rd-party extensions like Dataset/DataArray accessors (see [Section 2.7](#27-using-indexes-for-other-purposes)) diff --git a/design_notes/grouper_objects.md b/design_notes/grouper_objects.md index af42ef2f493..508ed5e9716 100644 --- a/design_notes/grouper_objects.md +++ b/design_notes/grouper_objects.md @@ -166,7 +166,7 @@ where `|` represents chunk boundaries. A simple rechunking to ``` 000|111122|3333 ``` -would make this resampling reduction an embarassingly parallel blockwise problem. +would make this resampling reduction an embarrassingly parallel blockwise problem. 
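The design note above argues that lining chunk boundaries up with group boundaries turns a resampling reduction into a purely blockwise computation. Below is a minimal sketch of that idea using public API this PR touches elsewhere (``Dataset.chunk`` with ``xarray.groupers.TimeResampler``); the toy dataset and the yearly frequency are illustrative assumptions, and dask must be installed for ``.chunk`` to have any effect.

```python
import numpy as np
import pandas as pd
import xarray as xr
from xarray.groupers import TimeResampler

# Illustrative data: three years of daily values (not part of the patch).
time = pd.date_range("2000-01-01", "2002-12-31", freq="D")
ds = xr.Dataset({"a": ("time", np.random.randn(time.size))}, coords={"time": time})

# Rechunk so that each chunk holds exactly one resampling period; the yearly
# mean below then reduces every chunk independently, which is the
# "embarrassingly parallel blockwise problem" described in the design note.
ds = ds.chunk(time=TimeResampler("YE"))
yearly_mean = ds.resample(time="YE").mean()
```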
Similarly consider monthly-mean climatologies for which the month numbers might be ``` diff --git a/design_notes/named_array_design_doc.md b/design_notes/named_array_design_doc.md index 074f8cf17e7..0050471cd01 100644 --- a/design_notes/named_array_design_doc.md +++ b/design_notes/named_array_design_doc.md @@ -258,7 +258,7 @@ Questions: Variable.coarsen_reshape Variable.rolling_window - Variable.set_dims # split this into broadcas_to and expand_dims + Variable.set_dims # split this into broadcast_to and expand_dims # Reordering/Reshaping diff --git a/doc/api.rst b/doc/api.rst index 6ed8d513934..d79c0612a98 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -529,9 +529,11 @@ Datetimelike properties DataArray.dt.quarter DataArray.dt.days_in_month DataArray.dt.daysinmonth + DataArray.dt.days_in_year DataArray.dt.season DataArray.dt.time DataArray.dt.date + DataArray.dt.decimal_year DataArray.dt.calendar DataArray.dt.is_month_start DataArray.dt.is_month_end diff --git a/doc/user-guide/dask.rst b/doc/user-guide/dask.rst index f71969066f9..5c421aa51d8 100644 --- a/doc/user-guide/dask.rst +++ b/doc/user-guide/dask.rst @@ -298,7 +298,7 @@ Automatic parallelization with ``apply_ufunc`` and ``map_blocks`` .. tip:: - Some problems can become embarassingly parallel and thus easy to parallelize + Some problems can become embarrassingly parallel and thus easy to parallelize automatically by rechunking to a frequency, e.g. ``ds.chunk(time=TimeResampler("YE"))``. See :py:meth:`Dataset.chunk` for more. @@ -559,7 +559,7 @@ larger chunksizes. .. tip:: - Many time domain problems become amenable to an embarassingly parallel or blockwise solution + Many time domain problems become amenable to an embarrassingly parallel or blockwise solution (e.g. using :py:func:`xarray.map_blocks`, :py:func:`dask.array.map_blocks`, or :py:func:`dask.array.blockwise`) by rechunking to a frequency along the time dimension. Provide :py:class:`xarray.groupers.TimeResampler` objects to :py:meth:`Dataset.chunk` to do so. diff --git a/doc/user-guide/data-structures.rst b/doc/user-guide/data-structures.rst index a1794f4123d..b963ccf0b00 100644 --- a/doc/user-guide/data-structures.rst +++ b/doc/user-guide/data-structures.rst @@ -289,7 +289,7 @@ pressure that were made under various conditions: * the measurements were made on four different days; * they were made at two separate locations, which we will represent using their latitude and longitude; and -* they were made using instruments by three different manufacutrers, which we +* they were made using instruments by three different manufacturers, which we will refer to as `'manufac1'`, `'manufac2'`, and `'manufac3'`. .. ipython:: python diff --git a/doc/user-guide/groupby.rst b/doc/user-guide/groupby.rst index c10ee6a659d..98bd7b4833b 100644 --- a/doc/user-guide/groupby.rst +++ b/doc/user-guide/groupby.rst @@ -305,6 +305,12 @@ Use grouper objects to group by multiple dimensions: from xarray.groupers import UniqueGrouper + da.groupby(["lat", "lon"]).sum() + +The above is sugar for using ``UniqueGrouper`` objects directly: + +.. 
ipython:: python + da.groupby(lat=UniqueGrouper(), lon=UniqueGrouper()).sum() diff --git a/doc/user-guide/pandas.rst b/doc/user-guide/pandas.rst index 26fa7ea5c0c..30939cbbd17 100644 --- a/doc/user-guide/pandas.rst +++ b/doc/user-guide/pandas.rst @@ -120,7 +120,7 @@ Particularly after a roundtrip, the following deviations are noted: - a non-dimension Dataset ``coordinate`` is converted into ``variable`` - a non-dimension DataArray ``coordinate`` is not converted -- ``dtype`` is not allways the same (e.g. "str" is converted to "object") +- ``dtype`` is not always the same (e.g. "str" is converted to "object") - ``attrs`` metadata is not conserved To avoid these problems, the third-party `ntv-pandas `__ library offers lossless and reversible conversions between diff --git a/doc/user-guide/testing.rst b/doc/user-guide/testing.rst index d82d9d7d7d9..434c0790139 100644 --- a/doc/user-guide/testing.rst +++ b/doc/user-guide/testing.rst @@ -193,7 +193,7 @@ different type: .. ipython:: python - def sparse_random_arrays(shape: tuple[int]) -> sparse._coo.core.COO: + def sparse_random_arrays(shape: tuple[int, ...]) -> sparse._coo.core.COO: """Strategy which generates random sparse.COO arrays""" if shape is None: shape = npst.array_shapes() diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 712ad68aeb3..b5e3ced3be8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,13 @@ v2024.07.1 (unreleased) New Features ~~~~~~~~~~~~ + +- Add :py:attr:`~core.accessor_dt.DatetimeAccessor.days_in_year` and :py:attr:`~core.accessor_dt.DatetimeAccessor.decimal_year` to the Datetime accessor on DataArrays. (:pull:`9105`). + By `Pascal Bourgault `_. + +Performance +~~~~~~~~~~~ + - Make chunk manager an option in ``set_options`` (:pull:`9362`). By `Tom White `_. - Support for :ref:`grouping by multiple variables `. @@ -35,6 +42,24 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ - Support for ``python 3.9`` has been dropped (:pull:`8937`) +- The minimum versions of some dependencies were changed + + ===================== ========= ======= + Package Old New + ===================== ========= ======= + boto3 1.26 1.28 + cartopy 0.21 0.22 + dask-core 2023.4 2023.9 + distributed 2023.4 2023.9 + h5netcdf 1.1 1.2 + iris 3.4 3.7 + numba 0.56 0.57 + numpy 1.23 1.24 + pandas 2.0 2.1 + scipy 1.10 1.11 + typing_extensions 4.5 4.7 + zarr 2.14 2.16 + ===================== ========= ======= Deprecations @@ -60,6 +85,16 @@ Bug fixes - Fix deprecation warning that was raised when calling ``np.array`` on an ``xr.DataArray`` in NumPy 2.0 (:issue:`9312`, :pull:`9393`) By `Andrew Scherer `_. +- Fix support for using ``pandas.DateOffset``, ``pandas.Timedelta``, and + ``datetime.timedelta`` objects as ``resample`` frequencies + (:issue:`9408`, :pull:`9413`). + By `Oliver Higgs `_. + +Performance +~~~~~~~~~~~ + +- Speed up grouping by avoiding deep-copy of non-dimension coordinates (:issue:`9426`, :pull:`9393`) + By `Deepak Cherian `_. Documentation ~~~~~~~~~~~~~ @@ -94,7 +129,7 @@ New Features (:issue:`6610`, :pull:`8840`). By `Deepak Cherian `_. - Allow rechunking to a frequency using ``Dataset.chunk(time=TimeResampler("YE"))`` syntax. (:issue:`7559`, :pull:`9109`) - Such rechunking allows many time domain analyses to be executed in an embarassingly parallel fashion. + Such rechunking allows many time domain analyses to be executed in an embarrassingly parallel fashion. By `Deepak Cherian `_. 
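The release notes above add ``DataArray.dt.days_in_year`` and ``DataArray.dt.decimal_year`` (PR 9105) and restore support for offset objects as ``resample`` frequencies (issue 9408). A small usage sketch consistent with those entries follows; the toy data is an assumption, not taken from the patch.

```python
import datetime
import numpy as np
import pandas as pd
import xarray as xr

time = xr.DataArray(pd.date_range("2000-01-01", periods=4, freq="MS"), dims="time")

# New datetime accessor properties: days in the calendar year, and the
# timestamp expressed as a fractional ("decimal") year.
time.dt.days_in_year  # 366 for every timestamp in the leap year 2000
time.dt.decimal_year  # 2000.0, ~2000.0847, ...

# Resample frequencies may again be timedelta-like objects (issue 9408).
da = xr.DataArray(np.arange(4), dims="time", coords={"time": time})
da.resample(time=datetime.timedelta(days=31)).mean()
```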
- Allow per-variable specification of ```mask_and_scale``, ``decode_times``, ``decode_timedelta`` ``use_cftime`` and ``concat_characters`` params in :py:func:`~xarray.open_dataset` (:pull:`9218`). @@ -127,7 +162,7 @@ Breaking changes Bug fixes ~~~~~~~~~ -- Fix scatter plot broadcasting unneccesarily. (:issue:`9129`, :pull:`9206`) +- Fix scatter plot broadcasting unnecessarily. (:issue:`9129`, :pull:`9206`) By `Jimmy Westling `_. - Don't convert custom indexes to ``pandas`` indexes when computing a diff (:pull:`9157`) By `Justus Magin `_. @@ -590,7 +625,7 @@ Internal Changes ~~~~~~~~~~~~~~~~ - The implementation of :py:func:`map_blocks` has changed to minimize graph size and duplication of data. - This should be a strict improvement even though the graphs are not always embarassingly parallel any more. + This should be a strict improvement even though the graphs are not always embarrassingly parallel any more. Please open an issue if you spot a regression. (:pull:`8412`, :issue:`8409`). By `Deepak Cherian `_. - Remove null values before plotting. (:pull:`8535`). diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 9e0d4640171..3f507e3f341 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -80,7 +80,7 @@ def test_roundtrip_dataarray(data, arr) -> None: tuple ) ) - coords = {name: np.arange(n) for (name, n) in zip(names, arr.shape)} + coords = {name: np.arange(n) for (name, n) in zip(names, arr.shape, strict=True)} original = xr.DataArray(arr, dims=names, coords=coords) roundtripped = xr.DataArray(original.to_pandas()) xr.testing.assert_identical(original, roundtripped) diff --git a/pyproject.toml b/pyproject.toml index de590559a64..9808dbf709a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,9 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ - "numpy>=1.23", + "numpy>=1.24", "packaging>=23.1", - "pandas>=2.0", + "pandas>=2.1", ] [project.optional-dependencies] @@ -84,14 +84,13 @@ source = ["xarray"] exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"] [tool.mypy] -enable_error_code = "redundant-self" +enable_error_code = ["ignore-without-code", "redundant-self", "redundant-expr"] exclude = [ 'build', 'xarray/util/generate_.*\.py', 'xarray/datatree_/doc/.*\.py', ] files = "xarray" -show_error_codes = true show_error_context = true warn_redundant_casts = true warn_unused_configs = true @@ -236,12 +235,10 @@ reportMissingTypeStubs = false # ] [tool.ruff] -builtins = ["ellipsis"] extend-exclude = [ "doc", "_typed_ops.pyi", ] -target-version = "py310" [tool.ruff.lint] # E402: module level import not at top of file @@ -250,13 +247,13 @@ target-version = "py310" extend-safe-fixes = [ "TID252", # absolute imports ] -ignore = [ +extend-ignore = [ "E402", "E501", "E731", - "UP007" + "UP007", ] -select = [ +extend-select = [ "F", # Pyflakes "E", # Pycodestyle "W", diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2c95a7b6bf3..1f6b6076799 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -51,7 +51,7 @@ try: from dask.delayed import Delayed except ImportError: - Delayed = None # type: ignore + Delayed = None # type: ignore[assignment, misc] from io import BufferedIOBase from xarray.backends.common import BackendEntrypoint @@ -167,7 +167,7 @@ def check_name(name: Hashable): check_name(k) -def _validate_attrs(dataset, invalid_netcdf=False): +def _validate_attrs(dataset, engine, invalid_netcdf=False): """`attrs` must have a string key and a value which is 
either: a number, a string, an ndarray, a list/tuple of numbers/strings, or a numpy.bool_. @@ -177,8 +177,8 @@ def _validate_attrs(dataset, invalid_netcdf=False): `invalid_netcdf=True`. """ - valid_types = (str, Number, np.ndarray, np.number, list, tuple) - if invalid_netcdf: + valid_types = (str, Number, np.ndarray, np.number, list, tuple, bytes) + if invalid_netcdf and engine == "h5netcdf": valid_types += (np.bool_,) def check_attr(name, value, valid_types): @@ -202,6 +202,23 @@ def check_attr(name, value, valid_types): f"{', '.join([vtype.__name__ for vtype in valid_types])}" ) + if isinstance(value, bytes) and engine == "h5netcdf": + try: + value.decode("utf-8") + except UnicodeDecodeError as e: + raise ValueError( + f"Invalid value provided for attribute '{name!r}': {value!r}. " + "Only binary data derived from UTF-8 encoded strings is allowed " + f"for the '{engine}' engine. Consider using the 'netcdf4' engine." + ) from e + + if b"\x00" in value: + raise ValueError( + f"Invalid value provided for attribute '{name!r}': {value!r}. " + f"Null characters are not permitted for the '{engine}' engine. " + "Consider using the 'netcdf4' engine." + ) + # Check attrs on the dataset itself for k, v in dataset.attrs.items(): check_attr(k, v, valid_types) @@ -1096,7 +1113,7 @@ def open_mfdataset( list(combined_ids_paths.keys()), list(combined_ids_paths.values()), ) - elif combine == "by_coords" and concat_dim is not None: + elif concat_dim is not None: raise ValueError( "When combine='by_coords', passing a value for `concat_dim` has no " "effect. To manually combine along a specific dimension you should " @@ -1353,7 +1370,7 @@ def to_netcdf( # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) - _validate_attrs(dataset, invalid_netcdf=invalid_netcdf and engine == "h5netcdf") + _validate_attrs(dataset, engine, invalid_netcdf) try: store_open = WRITEABLE_STORES[engine] @@ -1415,7 +1432,7 @@ def to_netcdf( store.sync() return target.getvalue() finally: - if not multifile and compute: + if not multifile and compute: # type: ignore[redundant-expr] store.close() if not compute: @@ -1568,8 +1585,9 @@ def save_mfdataset( multifile=True, **kwargs, ) - for ds, path, group in zip(datasets, paths, groups) - ] + for ds, path, group in zip(datasets, paths, groups, strict=True) + ], + strict=True, ) try: @@ -1583,7 +1601,10 @@ def save_mfdataset( import dask return dask.delayed( - [dask.delayed(_finalize_store)(w, s) for w, s in zip(writes, stores)] + [ + dask.delayed(_finalize_store)(w, s) + for w, s in zip(writes, stores, strict=True) + ] ) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 38cba9af212..dd169cdbc7e 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -431,7 +431,7 @@ def set_dimensions(self, variables, unlimited_dims=None): for v in unlimited_dims: # put unlimited_dims first dims[v] = None for v in variables.values(): - dims.update(dict(zip(v.dims, v.shape))) + dims.update(dict(zip(v.dims, v.shape, strict=True))) for dim, length in dims.items(): if dim in existing_dims and length != existing_dims[dim]: diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 86d84f532b1..9caaf013494 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -276,7 +276,7 @@ def __getstate__(self): def __setstate__(self, state) -> None: """Restore from a pickle.""" opener, args, mode, kwargs, lock, manager_id = state - self.__init__( # type: ignore + self.__init__( # 
type: ignore[misc] opener, *args, mode=mode, kwargs=kwargs, lock=lock, manager_id=manager_id ) diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 0b7ebbbeb0c..b252d9136d2 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -208,7 +208,9 @@ def open_store_variable(self, name, var): "shuffle": var.shuffle, } if var.chunks: - encoding["preferred_chunks"] = dict(zip(var.dimensions, var.chunks)) + encoding["preferred_chunks"] = dict( + zip(var.dimensions, var.chunks, strict=True) + ) # Convert h5py-style compression options to NetCDF4-Python # style, if possible if var.compression == "gzip": diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index ec2fe25216a..af2c15495d7 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -278,7 +278,9 @@ def _extract_nc4_variable_encoding( chunksizes = encoding["chunksizes"] chunks_too_big = any( c > d and dim not in unlimited_dims - for c, d, dim in zip(chunksizes, variable.shape, variable.dims) + for c, d, dim in zip( + chunksizes, variable.shape, variable.dims, strict=False + ) ) has_original_shape = "original_shape" in encoding changed_shape = ( @@ -446,7 +448,9 @@ def open_store_variable(self, name: str, var): else: encoding["contiguous"] = False encoding["chunksizes"] = tuple(chunking) - encoding["preferred_chunks"] = dict(zip(var.dimensions, chunking)) + encoding["preferred_chunks"] = dict( + zip(var.dimensions, chunking, strict=True) + ) # TODO: figure out how to round-trip "endian-ness" without raising # warnings from netCDF4 # encoding['endian'] = var.endian() diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index 5eb7f879ee5..8b707633a6d 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -199,7 +199,7 @@ def get_backend(engine: str | type[BackendEntrypoint]) -> BackendEntrypoint: "https://docs.xarray.dev/en/stable/getting-started-guide/installing.html" ) backend = engines[engine] - elif isinstance(engine, type) and issubclass(engine, BackendEntrypoint): + elif issubclass(engine, BackendEntrypoint): backend = engine() else: raise TypeError( diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 242507f9c20..31b367a178b 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -186,7 +186,7 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): # TODO: incorporate synchronizer to allow writes from multiple dask # threads if var_chunks and enc_chunks_tuple: - for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks): + for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks, strict=True): for dchunk in dchunks[:-1]: if dchunk % zchunk: base_error = ( @@ -548,13 +548,13 @@ def open_store_variable(self, name, zarr_array=None): encoding = { "chunks": zarr_array.chunks, - "preferred_chunks": dict(zip(dimensions, zarr_array.chunks)), + "preferred_chunks": dict(zip(dimensions, zarr_array.chunks, strict=True)), "compressor": zarr_array.compressor, "filters": zarr_array.filters, } # _FillValue needs to be in attributes, not encoding, so it will get # picked up by decode_cf - if getattr(zarr_array, "fill_value") is not None: + if zarr_array.fill_value is not None: attributes["_FillValue"] = zarr_array.fill_value return Variable(dimensions, data, attributes, encoding) @@ -576,7 +576,7 @@ def get_dimensions(self): dimensions = {} for k, v in self.zarr_group.arrays(): dim_names, _ = _get_zarr_dims_and_attrs(v, DIMENSION_KEY, try_nczarr) - for d, s in zip(dim_names, 
v.shape): + for d, s in zip(dim_names, v.shape, strict=True): if d in dimensions and dimensions[d] != s: raise ValueError( f"found conflicting lengths for dimension {d} " diff --git a/xarray/coding/calendar_ops.py b/xarray/coding/calendar_ops.py index 1b2875b26f1..22a19a63871 100644 --- a/xarray/coding/calendar_ops.py +++ b/xarray/coding/calendar_ops.py @@ -9,7 +9,12 @@ _should_cftime_be_used, convert_times, ) -from xarray.core.common import _contains_datetime_like_objects, is_np_datetime_like +from xarray.core.common import ( + _contains_datetime_like_objects, + full_like, + is_np_datetime_like, +) +from xarray.core.computation import apply_ufunc try: import cftime @@ -25,16 +30,6 @@ ] -def _days_in_year(year, calendar, use_cftime=True): - """Return the number of days in the input year according to the input calendar.""" - date_type = get_date_type(calendar, use_cftime=use_cftime) - if year == -1 and calendar in _CALENDARS_WITHOUT_YEAR_ZERO: - difference = date_type(year + 2, 1, 1) - date_type(year, 1, 1) - else: - difference = date_type(year + 1, 1, 1) - date_type(year, 1, 1) - return difference.days - - def convert_calendar( obj, calendar, @@ -191,11 +186,7 @@ def convert_calendar( # Special case for conversion involving 360_day calendar if align_on == "year": # Instead of translating dates directly, this tries to keep the position within a year similar. - new_doy = time.groupby(f"{dim}.year").map( - _interpolate_day_of_year, - target_calendar=calendar, - use_cftime=use_cftime, - ) + new_doy = _interpolate_day_of_year(time, target_calendar=calendar) elif align_on == "random": # The 5 days to remove are randomly chosen, one for each of the five 72-days periods of the year. new_doy = time.groupby(f"{dim}.year").map( @@ -207,7 +198,7 @@ def convert_calendar( _convert_to_new_calendar_with_new_day_of_year( date, newdoy, calendar, use_cftime ) - for date, newdoy in zip(time.variable._data.array, new_doy) + for date, newdoy in zip(time.variable._data.array, new_doy, strict=True) ], dims=(dim,), name=dim, @@ -242,16 +233,25 @@ def convert_calendar( return out -def _interpolate_day_of_year(time, target_calendar, use_cftime): - """Returns the nearest day in the target calendar of the corresponding - "decimal year" in the source calendar. - """ - year = int(time.dt.year[0]) - source_calendar = time.dt.calendar +def _is_leap_year(years, calendar): + func = np.vectorize(cftime.is_leap_year) + return func(years, calendar=calendar) + + +def _days_in_year(years, calendar): + """The number of days in the year according to given calendar.""" + if calendar == "360_day": + return full_like(years, 360) + return _is_leap_year(years, calendar).astype(int) + 365 + + +def _interpolate_day_of_year(times, target_calendar): + """Returns the nearest day in the target calendar of the corresponding "decimal year" in the source calendar.""" + source_calendar = times.dt.calendar return np.round( - _days_in_year(year, target_calendar, use_cftime) - * time.dt.dayofyear - / _days_in_year(year, source_calendar, use_cftime) + _days_in_year(times.dt.year, target_calendar) + * times.dt.dayofyear + / _days_in_year(times.dt.year, source_calendar) ).astype(int) @@ -260,18 +260,18 @@ def _random_day_of_year(time, target_calendar, use_cftime): Removes Feb 29th and five other days chosen randomly within five sections of 72 days. 
""" - year = int(time.dt.year[0]) + year = time.dt.year[0] source_calendar = time.dt.calendar new_doy = np.arange(360) + 1 rm_idx = np.random.default_rng().integers(0, 72, 5) + 72 * np.arange(5) if source_calendar == "360_day": for idx in rm_idx: new_doy[idx + 1 :] = new_doy[idx + 1 :] + 1 - if _days_in_year(year, target_calendar, use_cftime) == 366: + if _days_in_year(year, target_calendar) == 366: new_doy[new_doy >= 60] = new_doy[new_doy >= 60] + 1 elif target_calendar == "360_day": new_doy = np.insert(new_doy, rm_idx - np.arange(5), -1) - if _days_in_year(year, source_calendar, use_cftime) == 366: + if _days_in_year(year, source_calendar) == 366: new_doy = np.insert(new_doy, 60, -1) return new_doy[time.dt.dayofyear - 1] @@ -304,32 +304,45 @@ def _convert_to_new_calendar_with_new_day_of_year( return np.nan -def _datetime_to_decimal_year(times, dim="time", calendar=None): - """Convert a datetime DataArray to decimal years according to its calendar or the given one. +def _decimal_year_cftime(time, year, days_in_year, *, date_class): + year_start = date_class(year, 1, 1) + delta = np.timedelta64(time - year_start, "ns") + days_in_year = np.timedelta64(days_in_year, "D") + return year + delta / days_in_year + + +def _decimal_year_numpy(time, year, days_in_year, *, dtype): + time = np.asarray(time).astype(dtype) + year_start = np.datetime64(int(year) - 1970, "Y").astype(dtype) + delta = time - year_start + days_in_year = np.timedelta64(days_in_year, "D") + return year + delta / days_in_year + + +def _decimal_year(times): + """Convert a datetime DataArray to decimal years according to its calendar. The decimal year of a timestamp is its year plus its sub-year component converted to the fraction of its year. Ex: '2000-03-01 12:00' is 2000.1653 in a standard calendar, 2000.16301 in a "noleap" or 2000.16806 in a "360_day". """ - from xarray.core.dataarray import DataArray - - calendar = calendar or times.dt.calendar - - if is_np_datetime_like(times.dtype): - times = times.copy(data=convert_times(times.values, get_date_type("standard"))) - - def _make_index(time): - year = int(time.dt.year[0]) - doys = cftime.date2num(time, f"days since {year:04d}-01-01", calendar=calendar) - return DataArray( - year + doys / _days_in_year(year, calendar), - dims=(dim,), - coords=time.coords, - name=dim, - ) - - return times.groupby(f"{dim}.year").map(_make_index) + if times.dtype == "O": + function = _decimal_year_cftime + kwargs = {"date_class": get_date_type(times.dt.calendar, True)} + else: + function = _decimal_year_numpy + kwargs = {"dtype": times.dtype} + return apply_ufunc( + function, + times, + times.dt.year, + times.dt.days_in_year, + kwargs=kwargs, + vectorize=True, + dask="parallelized", + output_dtypes=[np.float64], + ) def interp_calendar(source, target, dim="time"): @@ -372,9 +385,7 @@ def interp_calendar(source, target, dim="time"): f"Both 'source.{dim}' and 'target' must contain datetime objects." 
) - source_calendar = source[dim].dt.calendar target_calendar = target.dt.calendar - if ( source[dim].time.dt.year == 0 ).any() and target_calendar in _CALENDARS_WITHOUT_YEAR_ZERO: @@ -383,8 +394,8 @@ def interp_calendar(source, target, dim="time"): ) out = source.copy() - out[dim] = _datetime_to_decimal_year(source[dim], dim=dim, calendar=source_calendar) - target_idx = _datetime_to_decimal_year(target, dim=dim, calendar=target_calendar) + out[dim] = _decimal_year(source[dim]) + target_idx = _decimal_year(target) out = out.interp(**{dim: target_idx}) out[dim] = target return out diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index f7bed2c13ef..c503e8ebcd3 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -47,7 +47,7 @@ from collections.abc import Mapping from datetime import datetime, timedelta from functools import partial -from typing import TYPE_CHECKING, ClassVar, Literal +from typing import TYPE_CHECKING, ClassVar, Literal, TypeVar import numpy as np import pandas as pd @@ -80,6 +80,7 @@ DayOption: TypeAlias = Literal["start", "end"] +T_FreqStr = TypeVar("T_FreqStr", str, None) def _nanosecond_precision_timestamp(*args, **kwargs): @@ -739,7 +740,7 @@ def _generate_anchored_deprecated_frequencies( return pairs -_DEPRECATED_FREQUENICES: dict[str, str] = { +_DEPRECATED_FREQUENCIES: dict[str, str] = { "A": "YE", "Y": "YE", "AS": "YS", @@ -765,18 +766,25 @@ def _generate_anchored_deprecated_frequencies( def _emit_freq_deprecation_warning(deprecated_freq): - recommended_freq = _DEPRECATED_FREQUENICES[deprecated_freq] + recommended_freq = _DEPRECATED_FREQUENCIES[deprecated_freq] message = _DEPRECATION_MESSAGE.format( deprecated_freq=deprecated_freq, recommended_freq=recommended_freq ) emit_user_level_warning(message, FutureWarning) -def to_offset(freq: BaseCFTimeOffset | str, warn: bool = True) -> BaseCFTimeOffset: +def to_offset( + freq: BaseCFTimeOffset | str | timedelta | pd.Timedelta | pd.DateOffset, + warn: bool = True, +) -> BaseCFTimeOffset: """Convert a frequency string to the appropriate subclass of BaseCFTimeOffset.""" if isinstance(freq, BaseCFTimeOffset): return freq + if isinstance(freq, timedelta | pd.Timedelta): + return delta_to_tick(freq) + if isinstance(freq, pd.DateOffset): + freq = _legacy_to_new_freq(freq.freqstr) match = re.match(_PATTERN, freq) if match is None: @@ -784,13 +792,41 @@ def to_offset(freq: BaseCFTimeOffset | str, warn: bool = True) -> BaseCFTimeOffs freq_data = match.groupdict() freq = freq_data["freq"] - if warn and freq in _DEPRECATED_FREQUENICES: + if warn and freq in _DEPRECATED_FREQUENCIES: _emit_freq_deprecation_warning(freq) multiples = freq_data["multiple"] multiples = 1 if multiples is None else int(multiples) return _FREQUENCIES[freq](n=multiples) +def delta_to_tick(delta: timedelta | pd.Timedelta) -> Tick: + """Adapted from pandas.tslib.delta_to_tick""" + if isinstance(delta, pd.Timedelta) and delta.nanoseconds != 0: + # pandas.Timedelta has nanoseconds, but these are not supported + raise ValueError( + "Unable to convert 'pandas.Timedelta' object with non-zero " + "nanoseconds to 'CFTimeOffset' object" + ) + if delta.microseconds == 0: + if delta.seconds == 0: + return Day(n=delta.days) + else: + seconds = delta.days * 86400 + delta.seconds + if seconds % 3600 == 0: + return Hour(n=seconds // 3600) + elif seconds % 60 == 0: + return Minute(n=seconds // 60) + else: + return Second(n=seconds) + else: + # Regardless of the days and seconds this will always be a Millisecond + # 
or Microsecond object + if delta.microseconds % 1_000 == 0: + return Millisecond(n=delta.microseconds // 1_000) + else: + return Microsecond(n=delta.microseconds) + + def to_cftime_datetime(date_str_or_date, calendar=None): if cftime is None: raise ModuleNotFoundError("No module named 'cftime'") @@ -1332,7 +1368,7 @@ def _new_to_legacy_freq(freq): return freq -def _legacy_to_new_freq(freq): +def _legacy_to_new_freq(freq: T_FreqStr) -> T_FreqStr: # to avoid internal deprecation warnings when freq is determined using pandas < 2.2 # TODO: remove once requiring pandas >= 2.2 diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index ae08abb06a7..d3a0fbb3dba 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -801,6 +801,11 @@ def round(self, freq): """ return self._round_via_method(freq, _round_to_nearest_half_even) + @property + def is_leap_year(self): + func = np.vectorize(cftime.is_leap_year) + return func(self.year, calendar=self.calendar) + def _parse_iso8601_without_reso(date_type, datetime_str): date, _ = _parse_iso8601_with_reso(date_type, datetime_str) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 70df8c6c390..cfdecd28a27 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -204,7 +204,7 @@ def _unpack_time_units_and_ref_date(units: str) -> tuple[str, pd.Timestamp]: def _decode_cf_datetime_dtype( - data, units: str, calendar: str, use_cftime: bool | None + data, units: str, calendar: str | None, use_cftime: bool | None ) -> np.dtype: # Verify that at least the first and last date can be decoded # successfully. Otherwise, tracebacks end up swallowed by @@ -704,7 +704,7 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray: def encode_cf_datetime( - dates: T_DuckArray, # type: ignore + dates: T_DuckArray, # type: ignore[misc] units: str | None = None, calendar: str | None = None, dtype: np.dtype | None = None, @@ -726,7 +726,7 @@ def encode_cf_datetime( def _eagerly_encode_cf_datetime( - dates: T_DuckArray, # type: ignore + dates: T_DuckArray, # type: ignore[misc] units: str | None = None, calendar: str | None = None, dtype: np.dtype | None = None, @@ -809,7 +809,7 @@ def _eagerly_encode_cf_datetime( def _encode_cf_datetime_within_map_blocks( - dates: T_DuckArray, # type: ignore + dates: T_DuckArray, # type: ignore[misc] units: str, calendar: str, dtype: np.dtype, @@ -859,7 +859,7 @@ def _lazily_encode_cf_datetime( def encode_cf_timedelta( - timedeltas: T_DuckArray, # type: ignore + timedeltas: T_DuckArray, # type: ignore[misc] units: str | None = None, dtype: np.dtype | None = None, ) -> tuple[T_DuckArray, str]: @@ -871,7 +871,7 @@ def encode_cf_timedelta( def _eagerly_encode_cf_timedelta( - timedeltas: T_DuckArray, # type: ignore + timedeltas: T_DuckArray, # type: ignore[misc] units: str | None = None, dtype: np.dtype | None = None, allow_units_modification: bool = True, @@ -923,7 +923,7 @@ def _eagerly_encode_cf_timedelta( def _encode_cf_timedelta_within_map_blocks( - timedeltas: T_DuckArray, # type:ignore + timedeltas: T_DuckArray, # type: ignore[misc] units: str, dtype: np.dtype, ) -> T_DuckArray: diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py index 41b982d268b..e73893d0f35 100644 --- a/xarray/core/accessor_dt.py +++ b/xarray/core/accessor_dt.py @@ -6,15 +6,17 @@ import numpy as np import pandas as pd +from xarray.coding.calendar_ops import _decimal_year from xarray.coding.times import infer_calendar_name from xarray.core import duck_array_ops from 
xarray.core.common import ( _contains_datetime_like_objects, + full_like, is_np_datetime_like, is_np_timedelta_like, ) from xarray.core.types import T_DataArray -from xarray.core.variable import IndexVariable +from xarray.core.variable import IndexVariable, Variable from xarray.namedarray.utils import is_duck_dask_array if TYPE_CHECKING: @@ -244,12 +246,22 @@ def _date_field(self, name: str, dtype: DTypeLike) -> T_DataArray: if dtype is None: dtype = self._obj.dtype result = _get_date_field(_index_or_data(self._obj), name, dtype) - newvar = self._obj.variable.copy(data=result, deep=False) + newvar = Variable( + dims=self._obj.dims, + attrs=self._obj.attrs, + encoding=self._obj.encoding, + data=result, + ) return self._obj._replace(newvar, name=name) def _tslib_round_accessor(self, name: str, freq: str) -> T_DataArray: result = _round_field(_index_or_data(self._obj), name, freq) - newvar = self._obj.variable.copy(data=result, deep=False) + newvar = Variable( + dims=self._obj.dims, + attrs=self._obj.attrs, + encoding=self._obj.encoding, + data=result, + ) return self._obj._replace(newvar, name=name) def floor(self, freq: str) -> T_DataArray: @@ -533,6 +545,33 @@ def calendar(self) -> CFCalendar: """ return infer_calendar_name(self._obj.data) + @property + def days_in_year(self) -> T_DataArray: + """Each datetime as the year plus the fraction of the year elapsed.""" + if self.calendar == "360_day": + result = full_like(self.year, 360) + else: + result = self.is_leap_year.astype(int) + 365 + newvar = Variable( + dims=self._obj.dims, + attrs=self._obj.attrs, + encoding=self._obj.encoding, + data=result, + ) + return self._obj._replace(newvar, name="days_in_year") + + @property + def decimal_year(self) -> T_DataArray: + """Convert the dates as a fractional year.""" + result = _decimal_year(self._obj) + newvar = Variable( + dims=self._obj.dims, + attrs=self._obj.attrs, + encoding=self._obj.encoding, + data=result, + ) + return self._obj._replace(newvar, name="decimal_year") + class TimedeltaAccessor(TimeAccessor[T_DataArray]): """Access Timedelta fields for DataArrays with Timedelta-like dtypes. 
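The accessor changes above replace ``self._obj.variable.copy(data=result, deep=False)`` with direct construction of a ``Variable``, so computed date fields no longer trigger a deep copy of slow cftime object arrays (the GH9426 performance fix). Here is a standalone sketch of the same pattern; the array is illustrative, and ``_replace`` is xarray-internal API mirrored from the diff.

```python
import numpy as np
import xarray as xr
from xarray.core.variable import Variable

da = xr.DataArray(np.arange(3.0), dims="time", attrs={"long_name": "example"})

result = da.values * 2  # stand-in for a freshly computed date field

# Reuse dims/attrs/encoding from the source object and wrap only the new data;
# nothing from the original variable is copied.
newvar = Variable(dims=da.dims, attrs=da.attrs, encoding=da.encoding, data=result)
doubled = da._replace(newvar, name="doubled")  # _replace is internal API
```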
diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index c1a2d958c83..e44ef75a88b 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -847,7 +847,7 @@ def normalize( normalized : same type as values """ - return self._apply(func=lambda x: normalize(form, x)) + return self._apply(func=lambda x: normalize(form, x)) # type: ignore[arg-type] def isalnum(self) -> T_DataArray: """ diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index a28376d2890..d6cdd45bb49 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -405,6 +405,7 @@ def align_indexes(self) -> None: zip( [joined_index] + matching_indexes, [joined_index_vars] + matching_index_vars, + strict=True, ) ) need_reindex = self._need_reindex(dims, cmp_indexes) @@ -412,7 +413,7 @@ def align_indexes(self) -> None: if len(matching_indexes) > 1: need_reindex = self._need_reindex( dims, - list(zip(matching_indexes, matching_index_vars)), + list(zip(matching_indexes, matching_index_vars, strict=True)), ) else: need_reindex = False @@ -557,7 +558,7 @@ def reindex_all(self) -> None: self.results = tuple( self._reindex_one(obj, matching_indexes) for obj, matching_indexes in zip( - self.objects, self.objects_matching_indexes + self.objects, self.objects_matching_indexes, strict=True ) ) @@ -952,7 +953,7 @@ def is_alignable(obj): fill_value=fill_value, ) - for position, key, aligned_obj in zip(positions, keys, aligned): + for position, key, aligned_obj in zip(positions, keys, aligned, strict=True): if key is no_key: out[position] = aligned_obj else: diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 4b4a07ddc77..c7dff9d249d 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -139,7 +139,8 @@ def _infer_concat_order_from_coords(datasets): # Append positions along extra dimension to structure which # encodes the multi-dimensional concatenation order tile_ids = [ - tile_id + (position,) for tile_id, position in zip(tile_ids, order) + tile_id + (position,) + for tile_id, position in zip(tile_ids, order, strict=True) ] if len(datasets) > 1 and not concat_dims: @@ -148,7 +149,7 @@ def _infer_concat_order_from_coords(datasets): "order the datasets for concatenation" ) - combined_ids = dict(zip(tile_ids, datasets)) + combined_ids = dict(zip(tile_ids, datasets, strict=True)) return combined_ids, concat_dims @@ -349,7 +350,7 @@ def _nested_combine( combined_ids = _infer_concat_order_from_positions(datasets) else: # Already sorted so just use the ids already passed - combined_ids = dict(zip(ids, datasets)) + combined_ids = dict(zip(ids, datasets, strict=True)) # Check that the inferred shape is combinable _check_shape_tile_ids(combined_ids) diff --git a/xarray/core/common.py b/xarray/core/common.py index 74c03f9baf5..f043b7be3dd 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime import warnings from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping from contextlib import suppress @@ -13,6 +14,7 @@ from xarray.core import dtypes, duck_array_ops, formatting, formatting_html, ops from xarray.core.indexing import BasicIndexer, ExplicitlyIndexed from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.types import ResampleCompatible from xarray.core.utils import ( Frozen, either_dict_or_kwargs, @@ -32,8 +34,6 @@ if TYPE_CHECKING: - import datetime - from numpy.typing import DTypeLike from xarray.core.dataarray import DataArray @@ -254,7 
+254,7 @@ def sizes(self: Any) -> Mapping[Hashable, int]: -------- Dataset.sizes """ - return Frozen(dict(zip(self.dims, self.shape))) + return Frozen(dict(zip(self.dims, self.shape, strict=True))) class AttrAccessMixin: @@ -891,14 +891,14 @@ def rolling_exp( def _resample( self, resample_cls: type[T_Resample], - indexer: Mapping[Hashable, str | Resampler] | None, + indexer: Mapping[Hashable, ResampleCompatible | Resampler] | None, skipna: bool | None, closed: SideOptions | None, label: SideOptions | None, offset: pd.Timedelta | datetime.timedelta | str | None, origin: str | DatetimeLike, restore_coord_dims: bool | None, - **indexer_kwargs: str | Resampler, + **indexer_kwargs: ResampleCompatible | Resampler, ) -> T_Resample: """Returns a Resample object for performing resampling operations. @@ -1078,14 +1078,18 @@ def _resample( ) grouper: Resampler - if isinstance(freq, str): + if isinstance(freq, ResampleCompatible): grouper = TimeResampler( freq=freq, closed=closed, label=label, origin=origin, offset=offset ) elif isinstance(freq, Resampler): grouper = freq else: - raise ValueError("freq must be a str or a Resampler object") + raise ValueError( + "freq must be an object of type 'str', 'datetime.timedelta', " + "'pandas.Timedelta', 'pandas.DateOffset', or 'TimeResampler'. " + f"Received {type(freq)} instead." + ) rgrouper = ResolvedGrouper(grouper, group, self) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 3e91efc1ede..91a184d55cd 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -71,9 +71,9 @@ class _UFuncSignature: Attributes ---------- - input_core_dims : tuple[tuple] + input_core_dims : tuple[tuple, ...] Core dimension names on each input variable. - output_core_dims : tuple[tuple] + output_core_dims : tuple[tuple, ...] Core dimension names on each output variable. """ @@ -326,7 +326,7 @@ def apply_dataarray_vfunc( variable, coords=coords, indexes=indexes, name=name, fastpath=True ) for variable, coords, indexes in zip( - result_var, result_coords, result_indexes + result_var, result_coords, result_indexes, strict=True ) ) else: @@ -407,7 +407,7 @@ def _unpack_dict_tuples( ) -> tuple[dict[Hashable, Variable], ...]: out: tuple[dict[Hashable, Variable], ...] = tuple({} for _ in range(num_outputs)) for name, values in result_vars.items(): - for value, results_dict in zip(values, out): + for value, results_dict in zip(values, out, strict=True): results_dict[name] = value return out @@ -422,7 +422,7 @@ def _check_core_dims(signature, variable_args, name): """ missing = [] for i, (core_dims, variable_arg) in enumerate( - zip(signature.input_core_dims, variable_args) + zip(signature.input_core_dims, variable_args, strict=True) ): # Check whether all the dims are on the variable. 
Note that we need the # `hasattr` to check for a dims property, to protect against the case where @@ -454,7 +454,7 @@ def apply_dict_of_variables_vfunc( grouped_by_name = collect_dict_values(args, names, fill_value) result_vars = {} - for name, variable_args in zip(names, grouped_by_name): + for name, variable_args in zip(names, grouped_by_name, strict=True): core_dim_present = _check_core_dims(signature, variable_args, name) if core_dim_present is True: result_vars[name] = func(*variable_args) @@ -546,7 +546,7 @@ def apply_dataset_vfunc( if signature.num_outputs > 1: out = tuple( _fast_dataset(*args) - for args in zip(result_vars, list_of_coords, list_of_indexes) + for args in zip(result_vars, list_of_coords, list_of_indexes, strict=True) ) else: (coord_vars,) = list_of_coords @@ -616,11 +616,13 @@ def apply_groupby_func(func, *args): iterator = itertools.repeat(arg) iterators.append(iterator) - applied: Iterator = (func(*zipped_args) for zipped_args in zip(*iterators)) + applied: Iterator = ( + func(*zipped_args) for zipped_args in zip(*iterators, strict=False) + ) applied_example, applied = peek_at(applied) combine = first_groupby._combine # type: ignore[attr-defined] if isinstance(applied_example, tuple): - combined = tuple(combine(output) for output in zip(*applied)) + combined = tuple(combine(output) for output in zip(*applied, strict=True)) else: combined = combine(applied) return combined @@ -637,7 +639,7 @@ def unified_dim_sizes( "broadcasting cannot handle duplicate " f"dimensions on a variable: {list(var.dims)}" ) - for dim, size in zip(var.dims, var.shape): + for dim, size in zip(var.dims, var.shape, strict=True): if dim not in exclude_dims: if dim not in dim_sizes: dim_sizes[dim] = size @@ -741,7 +743,7 @@ def apply_variable_ufunc( if isinstance(arg, Variable) else arg ) - for arg, core_dims in zip(args, signature.input_core_dims) + for arg, core_dims in zip(args, signature.input_core_dims, strict=True) ] if any(is_chunked_array(array) for array in input_data): @@ -766,7 +768,7 @@ def apply_variable_ufunc( allow_rechunk = dask_gufunc_kwargs.get("allow_rechunk", None) if allow_rechunk is None: for n, (data, core_dims) in enumerate( - zip(input_data, signature.input_core_dims) + zip(input_data, signature.input_core_dims, strict=True) ): if is_chunked_array(data): # core dimensions cannot span multiple chunks @@ -848,7 +850,7 @@ def func(*arrays): ) output: list[Variable] = [] - for dims, data in zip(output_dims, result_data): + for dims, data in zip(output_dims, result_data, strict=True): data = as_compatible_data(data) if data.ndim != len(dims): raise ValueError( @@ -2179,7 +2181,7 @@ def _calc_idxminmax( # Handle chunked arrays (e.g. dask). if is_chunked_array(array.data): chunkmanager = get_chunked_array_type(array.data) - chunks = dict(zip(array.dims, array.chunks)) + chunks = dict(zip(array.dims, array.chunks, strict=True)) dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim]) data = dask_coord[duck_array_ops.ravel(indx.data)] res = indx.copy(data=duck_array_ops.reshape(data, indx.shape)) @@ -2268,7 +2270,7 @@ def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, .. 
_, chunked_data = chunkmanager.unify_chunks(*unify_chunks_args) chunked_data_iter = iter(chunked_data) out: list[Dataset | DataArray] = [] - for obj, ds in zip(objects, datasets): + for obj, ds in zip(objects, datasets, strict=True): for k, v in ds._variables.items(): if v.chunks is not None: ds._variables[k] = v.copy(data=next(chunked_data_iter)) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 182cf8a23a1..1133d8cc373 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -400,7 +400,9 @@ def process_subset_opt(opt, subset): equals[k] = False # computed variables are not to be re-computed # again in the future - for ds, v in zip(datasets[1:], computed): + for ds, v in zip( + datasets[1:], computed, strict=False + ): ds.variables[k].data = v.data break else: @@ -583,7 +585,7 @@ def ensure_common_dims(vars, concat_dim_lengths): common_dims = tuple(utils.OrderedSet(d for v in vars for d in v.dims)) if dim_name not in common_dims: common_dims = (dim_name,) + common_dims - for var, dim_len in zip(vars, concat_dim_lengths): + for var, dim_len in zip(vars, concat_dim_lengths, strict=True): if var.dims != common_dims: common_shape = tuple(dims_sizes.get(d, dim_len) for d in common_dims) var = var.set_dims(common_dims, common_shape) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 3b852b962bf..8840ad7f8c3 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -877,7 +877,7 @@ def __delitem__(self, key: Hashable) -> None: assert_no_index_corrupted(self._data.xindexes, {key}) del self._data._coords[key] - if self._data._indexes is not None and key in self._data._indexes: + if key in self._data._indexes: del self._data._indexes[key] def _ipython_key_completions_(self): diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 1f0544c1041..4b6185edf38 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -12,6 +12,7 @@ ) from functools import partial from os import PathLike +from types import EllipsisType from typing import ( TYPE_CHECKING, Any, @@ -102,6 +103,7 @@ Dims, ErrorOptions, ErrorOptionsWithWarn, + GroupInput, InterpOptions, PadModeOptions, PadReflectOptions, @@ -109,6 +111,7 @@ QueryEngineOptions, QueryParserOptions, ReindexMethodOptions, + ResampleCompatible, Self, SideOptions, T_ChunkDimFreq, @@ -123,7 +126,7 @@ def _check_coords_dims(shape, coords, dim): - sizes = dict(zip(dim, shape)) + sizes = dict(zip(dim, shape, strict=True)) for k, v in coords.items(): if any(d not in dim for d in v.dims): raise ValueError( @@ -172,7 +175,7 @@ def _infer_coords_and_dims( if utils.is_dict_like(coords): dims = list(coords.keys()) else: - for n, (dim, coord) in enumerate(zip(dims, coords)): + for n, (dim, coord) in enumerate(zip(dims, coords, strict=True)): coord = as_variable( coord, name=dims[n], auto_convert=False ).to_index_variable() @@ -199,7 +202,7 @@ def _infer_coords_and_dims( if new_coords[k].dims == (k,): new_coords[k] = new_coords[k].to_index_variable() elif coords is not None: - for dim, coord in zip(dims_tuple, coords): + for dim, coord in zip(dims_tuple, coords, strict=True): var = as_variable(coord, name=dim, auto_convert=False) var.dims = (dim,) new_coords[dim] = var.to_index_variable() @@ -251,14 +254,14 @@ def __getitem__(self, key) -> T_DataArray: if not utils.is_dict_like(key): # expand the indexer so we can handle Ellipsis labels = indexing.expanded_indexer(key, self.data_array.ndim) - key = dict(zip(self.data_array.dims, labels)) + key = dict(zip(self.data_array.dims, labels, 
strict=True)) return self.data_array.sel(key) def __setitem__(self, key, value) -> None: if not utils.is_dict_like(key): # expand the indexer so we can handle Ellipsis labels = indexing.expanded_indexer(key, self.data_array.ndim) - key = dict(zip(self.data_array.dims, labels)) + key = dict(zip(self.data_array.dims, labels, strict=True)) dim_indexers = map_index_queries(self.data_array, key).dim_indexers self.data_array[dim_indexers] = value @@ -438,7 +441,7 @@ def __init__( name: Hashable | None = None, attrs: Mapping | None = None, # internal parameters - indexes: Mapping[Any, Index] | None = None, + indexes: Mapping[Hashable, Index] | None = None, fastpath: bool = False, ) -> None: if fastpath: @@ -486,7 +489,7 @@ def __init__( assert isinstance(coords, dict) self._coords = coords self._name = name - self._indexes = indexes # type: ignore[assignment] + self._indexes = dict(indexes) self._close = None @@ -536,7 +539,7 @@ def _replace_maybe_drop_dims( indexes = self._indexes elif variable.dims == self.dims: # Shape has changed (e.g. from reduce(..., keepdims=True) - new_sizes = dict(zip(self.dims, variable.shape)) + new_sizes = dict(zip(self.dims, variable.shape, strict=True)) coords = { k: v for k, v in self._coords.items() @@ -875,7 +878,7 @@ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: if utils.is_dict_like(key): return key key = indexing.expanded_indexer(key, self.ndim) - return dict(zip(self.dims, key)) + return dict(zip(self.dims, key, strict=True)) def _getitem_coord(self, key: Any) -> Self: from xarray.core.dataset import _get_virtual_variable @@ -883,7 +886,7 @@ def _getitem_coord(self, key: Any) -> Self: try: var = self._coords[key] except KeyError: - dim_sizes = dict(zip(self.dims, self.shape)) + dim_sizes = dict(zip(self.dims, self.shape, strict=True)) _, key, var = _get_virtual_variable(self._coords, key, dim_sizes) return self._replace_maybe_drop_dims(var, name=key) @@ -1436,7 +1439,7 @@ def chunk( "It will raise an error in the future. Instead use a dict with dimension names as keys.", category=DeprecationWarning, ) - chunk_mapping = dict(zip(self.dims, chunks)) + chunk_mapping = dict(zip(self.dims, chunks, strict=True)) else: chunk_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") @@ -2841,7 +2844,7 @@ def stack( dim: Mapping[Any, Sequence[Hashable]] | None = None, create_index: bool | None = True, index_cls: type[Index] = PandasMultiIndex, - **dim_kwargs: Sequence[Hashable], + **dim_kwargs: Sequence[Hashable | EllipsisType], ) -> Self: """ Stack any number of existing dimensions into a single new dimension. @@ -3919,7 +3922,7 @@ def to_dataframe( ds = self._to_dataset_whole(name=unique_name) if dim_order is None: - ordered_dims = dict(zip(self.dims, self.shape)) + ordered_dims = dict(zip(self.dims, self.shape, strict=True)) else: ordered_dims = ds._normalize_dim_order(dim_order=dim_order) @@ -4141,7 +4144,7 @@ def to_netcdf( # No problems with the name - so we're fine! 
dataset = self.to_dataset() - return to_netcdf( # type: ignore # mypy cannot resolve the overloads:( + return to_netcdf( # type: ignore[return-value] # mypy cannot resolve the overloads:( dataset, path, mode=mode, @@ -6706,9 +6709,7 @@ def interp_calendar( @_deprecate_positional_args("v2024.07.0") def groupby( self, - group: ( - Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None - ) = None, + group: GroupInput = None, *, squeeze: Literal[False] = False, restore_coord_dims: bool = False, @@ -6718,7 +6719,7 @@ def groupby( Parameters ---------- - group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper + group : str or DataArray or IndexVariable or sequence of hashable or mapping of hashable to Grouper Array whose unique values should be used to group this array. If a Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary, must map an existing variable name to a :py:class:`Grouper` instance. @@ -6769,6 +6770,52 @@ def groupby( Coordinates: * dayofyear (dayofyear) int64 3kB 1 2 3 4 5 6 7 ... 361 362 363 364 365 366 + >>> da = xr.DataArray( + ... data=np.arange(12).reshape((4, 3)), + ... dims=("x", "y"), + ... coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))}, + ... ) + + Grouping by a single variable is easy + + >>> da.groupby("letters") + + + Execute a reduction + + >>> da.groupby("letters").sum() + Size: 48B + array([[ 9., 11., 13.], + [ 9., 11., 13.]]) + Coordinates: + * letters (letters) object 16B 'a' 'b' + Dimensions without coordinates: y + + Grouping by multiple variables + + >>> da.groupby(["letters", "x"]) + + + Use Grouper objects to express more complicated GroupBy operations + + >>> from xarray.groupers import BinGrouper, UniqueGrouper + >>> + >>> da.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum() + Size: 96B + array([[[ 0., 1., 2.], + [nan, nan, nan]], + + [[nan, nan, nan], + [ 3., 4., 5.]]]) + Coordinates: + * x_bins (x_bins) object 16B (5, 15] (15, 25] + * letters (letters) object 16B 'a' 'b' + Dimensions without coordinates: y + + See Also -------- :ref:`groupby` @@ -6790,32 +6837,12 @@ def groupby( """ from xarray.core.groupby import ( DataArrayGroupBy, - ResolvedGrouper, + _parse_group_and_groupers, _validate_groupby_squeeze, ) - from xarray.groupers import UniqueGrouper _validate_groupby_squeeze(squeeze) - - if isinstance(group, Mapping): - groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore - group = None - - rgroupers: tuple[ResolvedGrouper, ...] - if group is not None: - if groupers: - raise ValueError( - "Providing a combination of `group` and **groupers is not supported." 
- ) - rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),) - else: - if not groupers: - raise ValueError("Either `group` or `**groupers` must be provided.") - rgroupers = tuple( - ResolvedGrouper(grouper, group, self) - for group, grouper in groupers.items() - ) - + rgroupers = _parse_group_and_groupers(self, group, groupers) return DataArrayGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims) @_deprecate_positional_args("v2024.07.0") @@ -7243,7 +7270,7 @@ def coarsen( @_deprecate_positional_args("v2024.07.0") def resample( self, - indexer: Mapping[Hashable, str | Resampler] | None = None, + indexer: Mapping[Hashable, ResampleCompatible | Resampler] | None = None, *, skipna: bool | None = None, closed: SideOptions | None = None, @@ -7251,7 +7278,7 @@ def resample( offset: pd.Timedelta | datetime.timedelta | str | None = None, origin: str | DatetimeLike = "start_day", restore_coord_dims: bool | None = None, - **indexer_kwargs: str | Resampler, + **indexer_kwargs: ResampleCompatible | Resampler, ) -> DataArrayResample: """Returns a Resample object for performing resampling operations. @@ -7262,7 +7289,7 @@ def resample( Parameters ---------- - indexer : Mapping of Hashable to str, optional + indexer : Mapping of Hashable to str, datetime.timedelta, pd.Timedelta, pd.DateOffset, or Resampler, optional Mapping from the dimension name to resample frequency [1]_. The dimension must be datetime-like. skipna : bool, optional @@ -7286,7 +7313,7 @@ def resample( restore_coord_dims : bool, optional If True, also restore the dimension order of multi-dimensional coordinates. - **indexer_kwargs : str + **indexer_kwargs : str, datetime.timedelta, pd.Timedelta, pd.DateOffset, or Resampler The keyword arguments form of ``indexer``. One of indexer or indexer_kwargs must be provided. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e14176f1589..08885e3cd8d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -23,6 +23,7 @@ from numbers import Number from operator import methodcaller from os import PathLike +from types import EllipsisType from typing import IO, TYPE_CHECKING, Any, Generic, Literal, cast, overload import numpy as np @@ -154,6 +155,7 @@ DsCompatible, ErrorOptions, ErrorOptionsWithWarn, + GroupInput, InterpOptions, JoinOptions, PadModeOptions, @@ -161,6 +163,7 @@ QueryEngineOptions, QueryParserOptions, ReindexMethodOptions, + ResampleCompatible, SideOptions, T_ChunkDimFreq, T_DatasetPadConstantValues, @@ -240,13 +243,13 @@ def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint): # Determine the explicit requested chunks. preferred_chunks = var.encoding.get("preferred_chunks", {}) preferred_chunk_shape = tuple( - preferred_chunks.get(dim, size) for dim, size in zip(dims, shape) + preferred_chunks.get(dim, size) for dim, size in zip(dims, shape, strict=True) ) if isinstance(chunks, Number) or (chunks == "auto"): chunks = dict.fromkeys(dims, chunks) chunk_shape = tuple( chunks.get(dim, None) or preferred_chunk_sizes - for dim, preferred_chunk_sizes in zip(dims, preferred_chunk_shape) + for dim, preferred_chunk_sizes in zip(dims, preferred_chunk_shape, strict=True) ) chunk_shape = chunkmanager.normalize_chunks( @@ -256,7 +259,7 @@ def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint): # Warn where requested chunks break preferred chunks, provided that the variable # contains data. 
if var.size: - for dim, size, chunk_sizes in zip(dims, shape, chunk_shape): + for dim, size, chunk_sizes in zip(dims, shape, chunk_shape, strict=True): try: preferred_chunk_sizes = preferred_chunks[dim] except KeyError: @@ -282,7 +285,7 @@ def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint): "degrade performance. Instead, consider rechunking after loading." ) - return dict(zip(dims, chunk_shape)) + return dict(zip(dims, chunk_shape, strict=True)) def _maybe_chunk( @@ -868,7 +871,7 @@ def load(self, **kwargs) -> Self: *lazy_data.values(), **kwargs ) - for k, data in zip(lazy_data, evaluated_data): + for k, data in zip(lazy_data, evaluated_data, strict=False): self.variables[k].data = data # load everything else sequentially @@ -1051,7 +1054,7 @@ def _persist_inplace(self, **kwargs) -> Self: # evaluate all the dask arrays simultaneously evaluated_data = dask.persist(*lazy_data.values(), **kwargs) - for k, data in zip(lazy_data, evaluated_data): + for k, data in zip(lazy_data, evaluated_data, strict=False): self.variables[k].data = data return self @@ -1651,11 +1654,13 @@ def __setitem__( f"setting ({len(value)})" ) if isinstance(value, Dataset): - self.update(dict(zip(keylist, value.data_vars.values()))) + self.update( + dict(zip(keylist, value.data_vars.values(), strict=True)) + ) elif isinstance(value, DataArray): raise ValueError("Cannot assign single DataArray to multiple keys") else: - self.update(dict(zip(keylist, value))) + self.update(dict(zip(keylist, value, strict=True))) else: raise ValueError(f"Unsupported key-type {type(key)}") @@ -2330,7 +2335,7 @@ def to_netcdf( encoding = {} from xarray.backends.api import to_netcdf - return to_netcdf( # type: ignore # mypy cannot resolve the overloads:( + return to_netcdf( # type: ignore[return-value] # mypy cannot resolve the overloads:( self, path, mode=mode, @@ -3047,7 +3052,7 @@ def isel( coord_names.remove(name) continue variables[name] = var - dims.update(zip(var.dims, var.shape)) + dims.update(zip(var.dims, var.shape, strict=True)) return self._construct_direct( variables=variables, @@ -4271,7 +4276,7 @@ def _rename_indexes( new_index_vars = new_index.create_variables( { new: self._variables[old] - for old, new in zip(coord_names, new_coord_names) + for old, new in zip(coord_names, new_coord_names, strict=True) } ) variables.update(new_index_vars) @@ -4778,9 +4783,9 @@ def expand_dims( raise ValueError("axis should not contain duplicate values") # We need to sort them to make sure `axis` equals to the # axis positions of the result array. 
- zip_axis_dim = sorted(zip(axis_pos, dim.items())) + zip_axis_dim = sorted(zip(axis_pos, dim.items(), strict=True)) - all_dims = list(zip(v.dims, v.shape)) + all_dims = list(zip(v.dims, v.shape, strict=True)) for d, c in zip_axis_dim: all_dims.insert(d, c) variables[k] = v.set_dims(dict(all_dims)) @@ -5304,7 +5309,7 @@ def _get_stack_index( def _stack_once( self, - dims: Sequence[Hashable | ellipsis], + dims: Sequence[Hashable | EllipsisType], new_dim: Hashable, index_cls: type[Index], create_index: bool | None = True, @@ -5364,10 +5369,10 @@ def _stack_once( @partial(deprecate_dims, old_name="dimensions") def stack( self, - dim: Mapping[Any, Sequence[Hashable | ellipsis]] | None = None, + dim: Mapping[Any, Sequence[Hashable | EllipsisType]] | None = None, create_index: bool | None = True, index_cls: type[Index] = PandasMultiIndex, - **dim_kwargs: Sequence[Hashable | ellipsis], + **dim_kwargs: Sequence[Hashable | EllipsisType], ) -> Self: """ Stack any number of existing dimensions into a single new dimension. @@ -7323,7 +7328,7 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]): ] index = self.coords.to_index([*ordered_dims]) broadcasted_df = pd.DataFrame( - dict(zip(non_extension_array_columns, data)), index=index + dict(zip(non_extension_array_columns, data, strict=True)), index=index ) for extension_array_column in extension_array_columns: extension_array = self.variables[extension_array_column].data.array @@ -7498,10 +7503,10 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: if isinstance(idx, pd.MultiIndex): dims = tuple( - name if name is not None else "level_%i" % n + name if name is not None else "level_%i" % n # type: ignore[redundant-expr] for n, name in enumerate(idx.names) ) - for dim, lev in zip(dims, idx.levels): + for dim, lev in zip(dims, idx.levels, strict=True): xr_idx = PandasIndex(lev, dim) indexes[dim] = xr_idx index_vars.update(xr_idx.create_variables()) @@ -9633,7 +9638,7 @@ def argmin(self, dim: Hashable | None = None, **kwargs) -> Self: ): # Return int index if single dimension is passed, and is not part of a # sequence - argmin_func = getattr(duck_array_ops, "argmin") + argmin_func = duck_array_ops.argmin return self.reduce( argmin_func, dim=None if dim is None else [dim], **kwargs ) @@ -9726,7 +9731,7 @@ def argmax(self, dim: Hashable | None = None, **kwargs) -> Self: ): # Return int index if single dimension is passed, and is not part of a # sequence - argmax_func = getattr(duck_array_ops, "argmax") + argmax_func = duck_array_ops.argmax return self.reduce( argmax_func, dim=None if dim is None else [dim], **kwargs ) @@ -9747,7 +9752,7 @@ def eval( Calculate an expression supplied as a string in the context of the dataset. This is currently experimental; the API may change particularly around - assignments, which currently returnn a ``Dataset`` with the additional variable. + assignments, which currently return a ``Dataset`` with the additional variable. Currently only the ``python`` engine is supported, which has the same performance as executing in python. @@ -10025,7 +10030,7 @@ def curvefit( f"dimensions {preserved_dims}." 
) for param, (lb, ub) in bounds.items(): - for label, bound in zip(("Lower", "Upper"), (lb, ub)): + for label, bound in zip(("Lower", "Upper"), (lb, ub), strict=True): if isinstance(bound, DataArray): unexpected = set(bound.dims) - set(preserved_dims) if unexpected: @@ -10331,9 +10336,7 @@ def interp_calendar( @_deprecate_positional_args("v2024.07.0") def groupby( self, - group: ( - Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None - ) = None, + group: GroupInput = None, *, squeeze: Literal[False] = False, restore_coord_dims: bool = False, @@ -10343,7 +10346,7 @@ def groupby( Parameters ---------- - group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper + group : str or DataArray or IndexVariable or sequence of hashable or mapping of hashable to Grouper Array whose unique values should be used to group this array. If a Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary, must map an existing variable name to a :py:class:`Grouper` instance. @@ -10365,6 +10368,51 @@ def groupby( A `DatasetGroupBy` object patterned after `pandas.GroupBy` that can be iterated over in the form of `(unique_value, grouped_array)` pairs. + Examples + -------- + >>> ds = xr.Dataset( + ... {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))}, + ... coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))}, + ... ) + + Grouping by a single variable is easy + + >>> ds.groupby("letters") + + + Execute a reduction + + >>> ds.groupby("letters").sum() + Size: 64B + Dimensions: (letters: 2, y: 3) + Coordinates: + * letters (letters) object 16B 'a' 'b' + Dimensions without coordinates: y + Data variables: + foo (letters, y) float64 48B 9.0 11.0 13.0 9.0 11.0 13.0 + + Grouping by multiple variables + + >>> ds.groupby(["letters", "x"]) + + + Use Grouper objects to express more complicated GroupBy operations + + >>> from xarray.groupers import BinGrouper, UniqueGrouper + >>> + >>> ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum() + Size: 128B + Dimensions: (y: 3, x_bins: 2, letters: 2) + Coordinates: + * x_bins (x_bins) object 16B (5, 15] (15, 25] + * letters (letters) object 16B 'a' 'b' + Dimensions without coordinates: y + Data variables: + foo (y, x_bins, letters) float64 96B 0.0 nan nan 3.0 ... nan nan 5.0 + See Also -------- :ref:`groupby` @@ -10386,31 +10434,12 @@ def groupby( """ from xarray.core.groupby import ( DatasetGroupBy, - ResolvedGrouper, + _parse_group_and_groupers, _validate_groupby_squeeze, ) - from xarray.groupers import UniqueGrouper _validate_groupby_squeeze(squeeze) - - if isinstance(group, Mapping): - groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore - group = None - - rgroupers: tuple[ResolvedGrouper, ...] - if group is not None: - if groupers: - raise ValueError( - "Providing a combination of `group` and **groupers is not supported." 
- ) - rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),) - else: - if not groupers: - raise ValueError("Either `group` or `**groupers` must be provided.") - rgroupers = tuple( - ResolvedGrouper(grouper, group, self) - for group, grouper in groupers.items() - ) + rgroupers = _parse_group_and_groupers(self, group, groupers) return DatasetGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims) @@ -10684,7 +10713,7 @@ def coarsen( @_deprecate_positional_args("v2024.07.0") def resample( self, - indexer: Mapping[Any, str | Resampler] | None = None, + indexer: Mapping[Any, ResampleCompatible | Resampler] | None = None, *, skipna: bool | None = None, closed: SideOptions | None = None, @@ -10692,7 +10721,7 @@ def resample( offset: pd.Timedelta | datetime.timedelta | str | None = None, origin: str | DatetimeLike = "start_day", restore_coord_dims: bool | None = None, - **indexer_kwargs: str | Resampler, + **indexer_kwargs: ResampleCompatible | Resampler, ) -> DatasetResample: """Returns a Resample object for performing resampling operations. @@ -10703,7 +10732,7 @@ def resample( Parameters ---------- - indexer : Mapping of Hashable to str, optional + indexer : Mapping of Hashable to str, datetime.timedelta, pd.Timedelta, pd.DateOffset, or Resampler, optional Mapping from the dimension name to resample frequency [1]_. The dimension must be datetime-like. skipna : bool, optional @@ -10727,7 +10756,7 @@ def resample( restore_coord_dims : bool, optional If True, also restore the dimension order of multi-dimensional coordinates. - **indexer_kwargs : str + **indexer_kwargs : str, datetime.timedelta, pd.Timedelta, pd.DateOffset, or Resampler The keyword arguments form of ``indexer``. One of indexer or indexer_kwargs must be provided. diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 1b8a5ffbf38..b12d861624a 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -11,7 +11,7 @@ Mapping, ) from html import escape -from typing import TYPE_CHECKING, Any, Generic, Literal, NoReturn, Union, overload +from typing import TYPE_CHECKING, Any, Literal, NoReturn, Union, overload from xarray.core import utils from xarray.core.alignment import align @@ -37,7 +37,7 @@ from xarray.core.indexes import Index, Indexes from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS -from xarray.core.treenode import NamedNode, NodePath, Tree +from xarray.core.treenode import NamedNode, NodePath from xarray.core.utils import ( Default, Frozen, @@ -91,17 +91,13 @@ def _collect_data_and_coord_variables( return data_variables, coord_variables -def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: - if isinstance(data, DataArray): - ds = data.to_dataset() - elif isinstance(data, Dataset): +def _to_new_dataset(data: Dataset | None) -> Dataset: + if isinstance(data, Dataset): ds = data.copy(deep=False) elif data is None: ds = Dataset() else: - raise TypeError( - f"data object is not an xarray Dataset, DataArray, or None, it is of type {type(data)}" - ) + raise TypeError(f"data object is not an xarray.Dataset, dict, or None: {data}") return ds @@ -234,8 +230,7 @@ def __getitem__(self, key: Mapping) -> Dataset: # type: ignore[overload-overlap ... @overload - def __getitem__(self, key: Hashable) -> DataArray: # type: ignore[overload-overlap] - ... + def __getitem__(self, key: Hashable) -> DataArray: ... 
# See: https://github.com/pydata/xarray/issues/8855 @overload @@ -370,8 +365,7 @@ class DataTree( MappedDataWithCoords, DataTreeArithmeticMixin, TreeAttrAccessMixin, - Generic[Tree], - Mapping, + Mapping[str, "DataArray | DataTree"], ): """ A tree-like hierarchical collection of xarray objects. @@ -423,8 +417,7 @@ class DataTree( def __init__( self, - data: Dataset | DataArray | None = None, - parent: DataTree | None = None, + data: Dataset | None = None, children: Mapping[str, DataTree] | None = None, name: str | None = None, ): @@ -437,11 +430,8 @@ def __init__( Parameters ---------- - data : Dataset, DataArray, or None, optional - Data to store under the .ds attribute of this node. DataArrays will - be promoted to Datasets. Default is None. - parent : DataTree, optional - Parent node to this node. Default is None. + data : Dataset, optional + Data to store under the .ds attribute of this node. children : Mapping[str, DataTree], optional Any child nodes of this node. Default is None. name : str, optional @@ -459,9 +449,10 @@ def __init__( children = {} super().__init__(name=name) - self._set_node_data(_coerce_to_dataset(data)) - self.parent = parent - self.children = children + self._set_node_data(_to_new_dataset(data)) + + # shallow copy to avoid modifying arguments in-place (see GH issue #9196) + self.children = {name: child.copy() for name, child in children.items()} def _set_node_data(self, ds: Dataset): data_vars, coord_vars = _collect_data_and_coord_variables(ds) @@ -481,7 +472,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: ) path = str(NodePath(parent.path) / name) node_ds = self.to_dataset(inherited=False) - parent_ds = parent._to_dataset_view(rebuild_dims=False) + parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True) _check_alignment(path, node_ds, parent_ds, self.children) @property @@ -498,41 +489,46 @@ def _dims(self) -> ChainMap[Hashable, int]: def _indexes(self) -> ChainMap[Hashable, Index]: return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents)) - @property - def parent(self: DataTree) -> DataTree | None: - """Parent of this node.""" - return self._parent - - @parent.setter - def parent(self: DataTree, new_parent: DataTree) -> None: - if new_parent and self.name is None: - raise ValueError("Cannot set an unnamed node as a child of another node") - self._set_parent(new_parent, self.name) - - def _to_dataset_view(self, rebuild_dims: bool) -> DatasetView: + def _to_dataset_view(self, rebuild_dims: bool, inherited: bool) -> DatasetView: + coord_vars = self._coord_variables if inherited else self._node_coord_variables variables = dict(self._data_variables) - variables |= self._coord_variables + variables |= coord_vars if rebuild_dims: dims = calculate_dimensions(variables) - else: - # Note: rebuild_dims=False can create technically invalid Dataset - # objects because it may not contain all dimensions on its direct - # member variables, e.g., consider: - # tree = DataTree.from_dict( - # { - # "/": xr.Dataset({"a": (("x",), [1, 2])}), # x has size 2 - # "/b/c": xr.Dataset({"d": (("x",), [3])}), # x has size1 - # } - # ) - # However, they are fine for internal use cases, for align() or - # building a repr(). + elif inherited: + # Note: rebuild_dims=False with inherited=True can create + # technically invalid Dataset objects because it still includes + # dimensions that are only defined on parent data variables (i.e. 
not present on any parent coordinate variables), e.g., + # consider: + # >>> tree = DataTree.from_dict( + # ... { + # ... "/": xr.Dataset({"foo": ("x", [1, 2])}), # x has size 2 + # ... "/b": xr.Dataset(), + # ... } + # ... ) + # >>> ds = tree["b"]._to_dataset_view(rebuild_dims=False, inherited=True) + # >>> ds + # Size: 0B + # Dimensions: (x: 2) + # Dimensions without coordinates: x + # Data variables: + # *empty* + # + # Notice the "x" dimension is still defined, even though there are no + # variables or coordinates. + # Normally this is not supposed to be possible in xarray's data model, but here it is useful internally for use cases where we + # want to inherit everything from parents nodes, e.g., for align() + # and repr(). + # The user should never be able to see this dimension via public API. dims = dict(self._dims) + else: + dims = dict(self._node_dims) return DatasetView._constructor( variables=variables, coord_names=set(self._coord_variables), dims=dims, attrs=self._attrs, - indexes=dict(self._indexes), + indexes=dict(self._indexes if inherited else self._node_indexes), encoding=self._encoding, close=None, ) @@ -551,11 +547,11 @@ def ds(self) -> DatasetView: -------- DataTree.to_dataset """ - return self._to_dataset_view(rebuild_dims=True) + return self._to_dataset_view(rebuild_dims=True, inherited=True) @ds.setter - def ds(self, data: Dataset | DataArray | None = None) -> None: - ds = _coerce_to_dataset(data) + def ds(self, data: Dataset | None = None) -> None: + ds = _to_new_dataset(data) self._replace_node(ds) def to_dataset(self, inherited: bool = True) -> Dataset: @@ -719,8 +715,8 @@ def __contains__(self, key: object) -> bool: def __bool__(self) -> bool: return bool(self._data_variables) or bool(self._children) - def __iter__(self) -> Iterator[Hashable]: - return itertools.chain(self._data_variables, self._children) + def __iter__(self) -> Iterator[str]: + return itertools.chain(self._data_variables, self._children) # type: ignore[arg-type] def __array__(self, dtype=None, copy=None): raise TypeError( @@ -758,7 +754,7 @@ def _replace_node( raise ValueError(f"node already contains a variable named {child_name}") parent_ds = ( - self.parent._to_dataset_view(rebuild_dims=False) + self.parent._to_dataset_view(rebuild_dims=False, inherited=True) if self.parent is not None else None ) @@ -819,8 +815,10 @@ def _copy_node( deep: bool = False, ) -> DataTree: """Copy just one node of a tree""" - data = self.ds.copy(deep=deep) - new_node: DataTree = DataTree(data, name=self.name) + data = self._to_dataset_view(rebuild_dims=False, inherited=False) + if deep: + data = data.copy(deep=True) + new_node = DataTree(data, name=self.name) return new_node def __copy__(self: DataTree) -> DataTree: @@ -897,7 +895,7 @@ def _set(self, key: str, val: DataTree | CoercibleValue) -> None: # create and assign a shallow copy here so as not to alter original name of node in grafted tree new_node = val.copy(deep=False) new_node.name = key - new_node.parent = self + new_node._set_parent(new_parent=self, child_name=key) else: if not isinstance(val, DataArray | Variable): # accommodate other types that can be coerced into Variables @@ -929,6 +927,24 @@ def __setitem__( else: raise ValueError("Invalid format for key") + def __delitem__(self, key: str) -> None: + """Remove a variable or child node from this datatree node.""" + if key in self.children: + super().__delitem__(key) + + elif key in self._node_coord_variables: + if key in self._node_indexes: + del self._node_indexes[key] + del 
self._node_coord_variables[key] + self._node_dims = calculate_dimensions(self.variables) + + elif key in self._data_variables: + del self._data_variables[key] + self._node_dims = calculate_dimensions(self.variables) + + else: + raise KeyError(key) + @overload def update(self, other: Dataset) -> None: ... @@ -1064,7 +1080,8 @@ def drop_nodes( @classmethod def from_dict( cls, - d: Mapping[str, Dataset | DataArray | DataTree | None], + d: Mapping[str, Dataset | DataTree | None], + /, name: str | None = None, ) -> DataTree: """ @@ -1073,10 +1090,10 @@ def from_dict( Parameters ---------- d : dict-like - A mapping from path names to xarray.Dataset, xarray.DataArray, or DataTree objects. + A mapping from path names to xarray.Dataset or DataTree objects. - Path names are to be given as unix-like path. If path names containing more than one part are given, new - tree nodes will be constructed as necessary. + Path names are to be given as unix-like path. If path names containing more than one + part are given, new tree nodes will be constructed as necessary. To assign data to the root node of the tree use "/" as the path. name : Hashable | None, optional @@ -1096,9 +1113,12 @@ def from_dict( root_data = d_cast.pop("/", None) if isinstance(root_data, DataTree): obj = root_data.copy() - obj.orphan() + elif root_data is None or isinstance(root_data, Dataset): + obj = cls(name=name, data=root_data, children=None) else: - obj = cls(name=name, data=root_data, parent=None, children=None) + raise TypeError( + f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}' + ) def depth(item) -> int: pathstr, _ = item @@ -1112,9 +1132,10 @@ def depth(item) -> int: node_name = NodePath(path).name if isinstance(data, DataTree): new_node = data.copy() - new_node.orphan() - else: + elif isinstance(data, Dataset) or data is None: new_node = cls(name=node_name, data=data) + else: + raise TypeError(f"invalid values: {data}") obj._set_item( path, new_node, @@ -1244,7 +1265,7 @@ def equals(self, other: DataTree, from_root: bool = True) -> bool: return all( [ node.ds.equals(other_node.ds) - for node, other_node in zip(self.subtree, other.subtree) + for node, other_node in zip(self.subtree, other.subtree, strict=True) ] ) @@ -1274,7 +1295,7 @@ def identical(self, other: DataTree, from_root=True) -> bool: return all( node.ds.identical(other_node.ds) - for node, other_node in zip(self.subtree, other.subtree) + for node, other_node in zip(self.subtree, other.subtree, strict=True) ) def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree: @@ -1521,7 +1542,7 @@ def to_netcdf( mode : {"w", "a"}, default: "w" Write ('w') or append ('a') mode. If mode='w', any existing file at this location will be overwritten. If mode='a', existing variables - will be overwritten. Only appies to the root group. + will be overwritten. Only applies to the root group. 
encoding : dict, optional Nested dictionary with variable names as keys and dictionaries of variable specific encodings as values, e.g., diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index 17630466016..1a581629ab8 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -157,6 +157,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: first_tree.subtree, *args_as_tree_length_iterables, *list(kwargs_as_tree_length_iterables.values()), + strict=False, ): node_args_as_datasetviews = [ a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args] @@ -168,6 +169,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: v.ds if isinstance(v, DataTree) else v for v in all_node_args[n_args:] ], + strict=True, ) ) func_with_error_context = _handle_errors_with_path_context( diff --git a/xarray/core/datatree_ops.py b/xarray/core/datatree_ops.py index bc64b44ae1e..9e87cda191c 100644 --- a/xarray/core/datatree_ops.py +++ b/xarray/core/datatree_ops.py @@ -214,7 +214,7 @@ def method_name(self, *args, **kwargs): new_method_docstring = insert_doc_addendum( orig_method_docstring, _MAPPED_DOCSTRING_ADDENDUM ) - setattr(target_cls_dict[method_name], "__doc__", new_method_docstring) + target_cls_dict[method_name].__doc__ = new_method_docstring def insert_doc_addendum(docstring: str | None, addendum: str) -> str | None: @@ -224,7 +224,7 @@ def insert_doc_addendum(docstring: str | None, addendum: str) -> str | None: Dataset directly as well as the mixins: DataWithCoords, DatasetAggregations, and DatasetOpsMixin. The majority of the docstrings fall into a parseable pattern. Those that - don't, just have the addendum appeneded after. None values are returned. + don't, just have the addendum appended after. None values are returned. """ if docstring is None: diff --git a/xarray/core/datatree_render.py b/xarray/core/datatree_render.py index 98cb4f91495..47e0358588d 100644 --- a/xarray/core/datatree_render.py +++ b/xarray/core/datatree_render.py @@ -51,11 +51,16 @@ def __init__(self): >>> from xarray.core.datatree import DataTree >>> from xarray.core.datatree_render import RenderDataTree - >>> root = DataTree(name="root") - >>> s0 = DataTree(name="sub0", parent=root) - >>> s0b = DataTree(name="sub0B", parent=s0) - >>> s0a = DataTree(name="sub0A", parent=s0) - >>> s1 = DataTree(name="sub1", parent=root) + >>> root = DataTree.from_dict( + ... { + ... "/": None, + ... "/sub0": None, + ... "/sub0/sub0B": None, + ... "/sub0/sub0A": None, + ... "/sub1": None, + ... }, + ... name="root", + ... ) >>> print(RenderDataTree(root)) Group: / @@ -98,11 +103,16 @@ def __init__( >>> from xarray import Dataset >>> from xarray.core.datatree import DataTree >>> from xarray.core.datatree_render import RenderDataTree - >>> root = DataTree(name="root", data=Dataset({"a": 0, "b": 1})) - >>> s0 = DataTree(name="sub0", parent=root, data=Dataset({"c": 2, "d": 3})) - >>> s0b = DataTree(name="sub0B", parent=s0, data=Dataset({"e": 4})) - >>> s0a = DataTree(name="sub0A", parent=s0, data=Dataset({"f": 5, "g": 6})) - >>> s1 = DataTree(name="sub1", parent=root, data=Dataset({"h": 7})) + >>> root = DataTree.from_dict( + ... { + ... "/": Dataset({"a": 0, "b": 1}), + ... "/sub0": Dataset({"c": 2, "d": 3}), + ... "/sub0/sub0B": Dataset({"e": 4}), + ... "/sub0/sub0A": Dataset({"f": 5, "g": 6}), + ... "/sub1": Dataset({"h": 7}), + ... }, + ... name="root", + ... 
) # Simple one line: @@ -208,17 +218,16 @@ def by_attr(self, attrname: str = "name") -> str: >>> from xarray import Dataset >>> from xarray.core.datatree import DataTree >>> from xarray.core.datatree_render import RenderDataTree - >>> root = DataTree(name="root") - >>> s0 = DataTree(name="sub0", parent=root) - >>> s0b = DataTree( - ... name="sub0B", parent=s0, data=Dataset({"foo": 4, "bar": 109}) + >>> root = DataTree.from_dict( + ... { + ... "/sub0/sub0B": Dataset({"foo": 4, "bar": 109}), + ... "/sub0/sub0A": None, + ... "/sub1/sub1A": None, + ... "/sub1/sub1B": Dataset({"bar": 8}), + ... "/sub1/sub1C/sub1Ca": None, + ... }, + ... name="root", ... ) - >>> s0a = DataTree(name="sub0A", parent=s0) - >>> s1 = DataTree(name="sub1", parent=root) - >>> s1a = DataTree(name="sub1A", parent=s1) - >>> s1b = DataTree(name="sub1B", parent=s1, data=Dataset({"bar": 8})) - >>> s1c = DataTree(name="sub1C", parent=s1) - >>> s1ca = DataTree(name="sub1Ca", parent=s1c) >>> print(RenderDataTree(root).by_attr("name")) root ├── sub0 diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index b39f7628fd3..7464c1e8a89 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -64,7 +64,7 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: if isdtype(dtype, "real floating"): dtype_ = dtype fill_value = np.nan - elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.timedelta64): + elif np.issubdtype(dtype, np.timedelta64): # See https://github.com/numpy/numpy/issues/10685 # np.timedelta64 is a subclass of np.integer # Check np.timedelta64 before np.integer @@ -76,7 +76,7 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: elif isdtype(dtype, "complex floating"): dtype_ = dtype fill_value = np.nan + np.nan * 1j - elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.datetime64): + elif np.issubdtype(dtype, np.datetime64): dtype_ = dtype fill_value = np.datetime64("NaT") else: @@ -200,7 +200,7 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool: # numpy>=2 and pandas extensions arrays are implemented in # Xarray via the array API if not isinstance(kind, str) and not ( - isinstance(kind, tuple) and all(isinstance(k, str) for k in kind) + isinstance(kind, tuple) and all(isinstance(k, str) for k in kind) # type: ignore[redundant-expr] ): raise TypeError(f"kind must be a string or a tuple of strings: {repr(kind)}") diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index ec78588c527..110f80f8f5f 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -474,7 +474,7 @@ def prefixes(length: int) -> list[str]: preformatted = [ pretty_print(f" {prefix} {name}", col_width) - for prefix, name in zip(prefixes(len(names)), names) + for prefix, name in zip(prefixes(len(names)), names, strict=True) ] head, *tail = preformatted @@ -862,7 +862,7 @@ def extra_items_repr(extra_keys, mapping, ab_side, kwargs): temp = [ "\n".join([var_s, attr_s]) if attr_s else var_s - for var_s, attr_s in zip(temp, attrs_summary) + for var_s, attr_s in zip(temp, attrs_summary, strict=True) ] # TODO: It should be possible recursively use _diff_mapping_repr @@ -877,7 +877,9 @@ def extra_items_repr(extra_keys, mapping, ab_side, kwargs): # ) # temp += [newdiff] - diff_items += [ab_side + s[1:] for ab_side, s in zip(("L", "R"), temp)] + diff_items += [ + ab_side + s[1:] for ab_side, s in zip(("L", "R"), temp, strict=True) + ] if diff_items: summary += [f"Differing {title.lower()}:"] + diff_items @@ -941,7 +943,7 @@ def diff_array_repr(a, b, compat): temp = 
[wrap_indent(short_array_repr(obj), start=" ") for obj in (a, b)] diff_data_repr = [ ab_side + "\n" + ab_data_repr - for ab_side, ab_data_repr in zip(("L", "R"), temp) + for ab_side, ab_data_repr in zip(("L", "R"), temp, strict=True) ] summary += ["Differing values:"] + diff_data_repr @@ -966,7 +968,7 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s # Walking nodes in "level-order" fashion means walking down from the root breadth-first. # Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree # (which it is so long as children are stored in a tuple or list rather than in a set). - for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)): + for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b), strict=True): path_a, path_b = node_a.path, node_b.path if require_names_equal and node_a.name != node_b.name: @@ -1013,7 +1015,7 @@ def diff_nodewise_summary(a: DataTree, b: DataTree, compat): compat_str = _compat_to_str(compat) summary = [] - for node_a, node_b in zip(a.subtree, b.subtree): + for node_a, node_b in zip(a.subtree, b.subtree, strict=True): a_ds, b_ds = node_a.ds, node_b.ds if not a_ds._all_compat(b_ds, compat): @@ -1051,7 +1053,10 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat): def _single_node_repr(node: DataTree) -> str: """Information about this node, not including its relationships to other nodes.""" if node.has_data or node.has_attrs: - ds_info = "\n" + repr(node._to_dataset_view(rebuild_dims=False)) + # TODO: change this to inherited=False, in order to clarify what is + # inherited? https://github.com/pydata/xarray/issues/9463 + node_view = node._to_dataset_view(rebuild_dims=False, inherited=True) + ds_info = "\n" + repr(node_view) else: ds_info = "" return f"Group: {node.path}{ds_info}" diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 24b290031eb..34c7a93bd7a 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -303,7 +303,7 @@ def _obj_repr(obj, header_components, sections): def array_repr(arr) -> str: - dims = OrderedDict((k, v) for k, v in zip(arr.dims, arr.shape)) + dims = OrderedDict((k, v) for k, v in zip(arr.dims, arr.shape, strict=True)) if hasattr(arr, "xindexes"): indexed_dims = arr.xindexes.dims else: @@ -386,7 +386,7 @@ def summarize_datatree_children(children: Mapping[str, DataTree]) -> str: def datatree_node_repr(group_title: str, dt: DataTree) -> str: header_components = [f"
<div class='xr-obj-type'>{escape(group_title)}</div>
"] - ds = dt._to_dataset_view(rebuild_dims=False) + ds = dt._to_dataset_view(rebuild_dims=False, inherited=True) sections = [ children_section(dt.children), diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index cc83b32adc8..a5e520b98b6 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Literal, Union +from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast import numpy as np import pandas as pd @@ -54,7 +54,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import GroupIndex, GroupIndices, GroupKey + from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey from xarray.core.utils import Frozen from xarray.groupers import EncodedGroups, Grouper @@ -197,7 +197,7 @@ def __array__(self) -> np.ndarray: return np.arange(self.size) @property - def shape(self) -> tuple[int]: + def shape(self) -> tuple[int, ...]: return (self.size,) @property @@ -319,6 +319,51 @@ def __len__(self) -> int: return len(self.encoded.full_index) +def _parse_group_and_groupers( + obj: T_Xarray, group: GroupInput, groupers: dict[str, Grouper] +) -> tuple[ResolvedGrouper, ...]: + from xarray.core.dataarray import DataArray + from xarray.core.variable import Variable + from xarray.groupers import UniqueGrouper + + if group is not None and groupers: + raise ValueError( + "Providing a combination of `group` and **groupers is not supported." + ) + + if group is None and not groupers: + raise ValueError("Either `group` or `**groupers` must be provided.") + + if isinstance(group, np.ndarray | pd.Index): + raise TypeError( + f"`group` must be a DataArray. Received {type(group).__name__!r} instead" + ) + + if isinstance(group, Mapping): + grouper_mapping = either_dict_or_kwargs(group, groupers, "groupby") + group = None + + rgroupers: tuple[ResolvedGrouper, ...] + if isinstance(group, DataArray | Variable): + rgroupers = (ResolvedGrouper(UniqueGrouper(), group, obj),) + else: + if group is not None: + if TYPE_CHECKING: + assert isinstance(group, str | Sequence) + group_iter: Sequence[Hashable] = ( + (group,) if isinstance(group, str) else group + ) + grouper_mapping = {g: UniqueGrouper() for g in group_iter} + elif groupers: + grouper_mapping = cast("Mapping[Hashable, Grouper]", groupers) + + rgroupers = tuple( + ResolvedGrouper(grouper, group, obj) + for group, grouper in grouper_mapping.items() + ) + return rgroupers + + def _validate_groupby_squeeze(squeeze: Literal[False]) -> None: # While we don't generally check the type of every arg, passing # multiple dimensions as multiple arguments is common enough, and the @@ -327,7 +372,7 @@ def _validate_groupby_squeeze(squeeze: Literal[False]) -> None: # A future version could make squeeze kwarg only, but would face # backward-compat issues. 
if squeeze is not False: - raise TypeError(f"`squeeze` must be False, but {squeeze} was supplied.") + raise TypeError(f"`squeeze` must be False, but {squeeze!r} was supplied.") def _resolve_group( @@ -413,7 +458,7 @@ def factorize(self) -> EncodedGroups: ) # NaNs; as well as values outside the bins are coded by -1 # Restore these after the raveling - mask = functools.reduce(np.logical_or, [(code == -1) for code in broadcasted_codes]) # type: ignore + mask = functools.reduce(np.logical_or, [(code == -1) for code in broadcasted_codes]) # type: ignore[arg-type] _flatcodes[mask] = -1 midx = pd.MultiIndex.from_product( @@ -601,7 +646,11 @@ def groups(self) -> dict[GroupKey, GroupIndex]: # provided to mimic pandas.groupby if self._groups is None: self._groups = dict( - zip(self.encoded.unique_coord.data, self.encoded.group_indices) + zip( + self.encoded.unique_coord.data, + self.encoded.group_indices, + strict=True, + ) ) return self._groups @@ -615,7 +664,7 @@ def __len__(self) -> int: return self._len def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]: - return zip(self.encoded.unique_coord.data, self._iter_grouped()) + return zip(self.encoded.unique_coord.data, self._iter_grouped(), strict=True) def __repr__(self) -> str: text = ( @@ -626,7 +675,7 @@ def __repr__(self) -> str: for grouper in self.groupers: coord = grouper.unique_coord labels = ", ".join(format_array_flat(coord, 30).split()) - text += f"\n\t{grouper.name!r}: {coord.size} groups with labels {labels}" + text += f"\n {grouper.name!r}: {coord.size} groups with labels {labels}" return text + ">" def _iter_grouped(self) -> Iterator[T_Xarray]: @@ -800,7 +849,7 @@ def _flox_reduce( obj = self._original_obj variables = ( {k: v.variable for k, v in obj.data_vars.items()} - if isinstance(obj, Dataset) + if isinstance(obj, Dataset) # type: ignore[redundant-expr] # seems to be a mypy bug else obj._coords ) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 80d15f8cde9..35870064db5 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -644,7 +644,7 @@ def from_variables( # preserve wrapped pd.Index (if any) # accessing `.data` can load data from disk, so we only access if needed - data = getattr(var._data, "array") if hasattr(var._data, "array") else var.data + data = var._data.array if hasattr(var._data, "array") else var.data # multi-index level variable: get level index if isinstance(var._data, PandasMultiIndexingAdapter): level = var._data.level @@ -1024,14 +1024,16 @@ def stack( _check_dim_compat(variables, all_dims="different") level_indexes = [safe_cast_to_index(var) for var in variables.values()] - for name, idx in zip(variables, level_indexes): + for name, idx in zip(variables, level_indexes, strict=True): if isinstance(idx, pd.MultiIndex): raise ValueError( f"cannot create a multi-index along stacked dimension {dim!r} " f"from variable {name!r} that wraps a multi-index" ) - split_labels, levels = zip(*[lev.factorize() for lev in level_indexes]) + split_labels, levels = zip( + *[lev.factorize() for lev in level_indexes], strict=True + ) labels_mesh = np.meshgrid(*split_labels, indexing="ij") labels = [x.ravel() for x in labels_mesh] @@ -1051,7 +1053,7 @@ def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]: ) new_indexes: dict[Hashable, Index] = {} - for name, lev in zip(clean_index.names, clean_index.levels): + for name, lev in zip(clean_index.names, clean_index.levels, strict=True): idx = PandasIndex( lev.copy(), name, coord_dtype=self.level_coords_dtype[name] ) @@ -1258,7 +1260,9 
@@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: else: levels = [self.index.names[i] for i in range(len(label))] indexer, new_index = self.index.get_loc_level(label, level=levels) - scalar_coord_values.update({k: v for k, v in zip(levels, label)}) + scalar_coord_values.update( + {k: v for k, v in zip(levels, label, strict=True)} + ) else: label_array = normalize_label(label) @@ -1360,7 +1364,8 @@ def rename(self, name_dict, dims_dict): new_dim = dims_dict.get(self.dim, self.dim) new_level_coords_dtype = { - k: v for k, v in zip(new_names, self.level_coords_dtype.values()) + k: v + for k, v in zip(new_names, self.level_coords_dtype.values(), strict=True) } return self._replace( index, dim=new_dim, level_coords_dtype=new_level_coords_dtype @@ -1802,7 +1807,7 @@ def check_variables(): def _apply_indexes_fast(indexes: Indexes[Index], args: Mapping[Any, Any], func: str): # This function avoids the call to indexes.group_by_index - # which is really slow when repeatidly iterating through + # which is really slow when repeatedly iterating through # an array. However, it fails to return the correct ID for # multi-index arrays indexes_fast, coords = indexes._indexes, indexes._variables diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 1f5444b6baa..06b4b9a475f 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -617,7 +617,7 @@ def __init__(self, array: Any, key: ExplicitIndexer | None = None): self.key = key shape: _Shape = () - for size, k in zip(self.array.shape, self.key.tuple): + for size, k in zip(self.array.shape, self.key.tuple, strict=True): if isinstance(k, slice): shape += (len(range(*k.indices(size))),) elif isinstance(k, np.ndarray): @@ -627,7 +627,7 @@ def __init__(self, array: Any, key: ExplicitIndexer | None = None): def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer: iter_new_key = iter(expanded_indexer(new_key.tuple, self.ndim)) full_key = [] - for size, k in zip(self.array.shape, self.key.tuple): + for size, k in zip(self.array.shape, self.key.tuple, strict=True): if isinstance(k, integer_types): full_key.append(k) else: @@ -907,7 +907,7 @@ def _outer_to_vectorized_indexer( n_dim = len([k for k in key if not isinstance(k, integer_types)]) i_dim = 0 new_key = [] - for k, size in zip(key, shape): + for k, size in zip(key, shape, strict=True): if isinstance(k, integer_types): new_key.append(np.array(k).reshape((1,) * n_dim)) else: # np.ndarray or slice @@ -1127,10 +1127,10 @@ def _decompose_vectorized_indexer( # convert negative indices indexer_elems = [ np.where(k < 0, k + s, k) if isinstance(k, np.ndarray) else k - for k, s in zip(indexer.tuple, shape) + for k, s in zip(indexer.tuple, shape, strict=True) ] - for k, s in zip(indexer_elems, shape): + for k, s in zip(indexer_elems, shape, strict=True): if isinstance(k, slice): # If it is a slice, then we will slice it as-is # (but make its step positive) in the backend, @@ -1207,7 +1207,7 @@ def _decompose_outer_indexer( assert isinstance(indexer, OuterIndexer | BasicIndexer) if indexing_support == IndexingSupport.VECTORIZED: - for k, s in zip(indexer.tuple, shape): + for k, s in zip(indexer.tuple, shape, strict=False): if isinstance(k, slice): # If it is a slice, then we will slice it as-is # (but make its step positive) in the backend, @@ -1222,7 +1222,7 @@ def _decompose_outer_indexer( # make indexer positive pos_indexer: list[np.ndarray | int | np.number] = [] - for k, s in zip(indexer.tuple, shape): + for k, s in zip(indexer.tuple, shape, 
strict=False): if isinstance(k, np.ndarray): pos_indexer.append(np.where(k < 0, k + s, k)) elif isinstance(k, integer_types) and k < 0: @@ -1244,7 +1244,7 @@ def _decompose_outer_indexer( ] array_index = np.argmax(np.array(gains)) if len(gains) > 0 else None - for i, (k, s) in enumerate(zip(indexer_elems, shape)): + for i, (k, s) in enumerate(zip(indexer_elems, shape, strict=False)): if isinstance(k, np.ndarray) and i != array_index: # np.ndarray key is converted to slice that covers the entire # entries of this key. @@ -1265,7 +1265,7 @@ def _decompose_outer_indexer( return (OuterIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) if indexing_support == IndexingSupport.OUTER: - for k, s in zip(indexer_elems, shape): + for k, s in zip(indexer_elems, shape, strict=False): if isinstance(k, slice): # slice: convert positive step slice for backend bk_slice, np_slice = _decompose_slice(k, s) @@ -1287,7 +1287,7 @@ def _decompose_outer_indexer( # basic indexer assert indexing_support == IndexingSupport.BASIC - for k, s in zip(indexer_elems, shape): + for k, s in zip(indexer_elems, shape, strict=False): if isinstance(k, np.ndarray): # np.ndarray key is converted to slice that covers the entire # entries of this key. @@ -1315,7 +1315,7 @@ def _arrayize_vectorized_indexer( n_dim = arrays[0].ndim if len(arrays) > 0 else 0 i_dim = 0 new_key = [] - for v, size in zip(indexer.tuple, shape): + for v, size in zip(indexer.tuple, shape, strict=True): if isinstance(v, np.ndarray): new_key.append(np.reshape(v, v.shape + (1,) * len(slices))) else: # slice @@ -1333,7 +1333,7 @@ def _chunked_array_with_chunks_hint( if len(chunks) < array.ndim: raise ValueError("not enough chunks in hint") new_chunks = [] - for chunk, size in zip(chunks, array.shape): + for chunk, size in zip(chunks, array.shape, strict=False): new_chunks.append(chunk if size > 1 else (1,)) return chunkmanager.from_array(array, new_chunks) # type: ignore[arg-type] @@ -1399,7 +1399,7 @@ def create_mask( base_mask = _masked_result_drop_slice(key, data) slice_shape = tuple( np.arange(*k.indices(size)).size - for k, size in zip(key, shape) + for k, size in zip(key, shape, strict=False) if isinstance(k, slice) ) expanded_mask = base_mask[(Ellipsis,) + (np.newaxis,) * len(slice_shape)] @@ -1711,7 +1711,7 @@ def _convert_scalar(self, item): # a NumPy array. return to_0d_array(item) - def _prepare_key(self, key: tuple[Any, ...]) -> tuple[Any, ...]: + def _prepare_key(self, key: Any | tuple[Any, ...]) -> tuple[Any, ...]: if isinstance(key, tuple) and len(key) == 1: # unpack key so it can index a pandas.Index object (pandas.Index # objects don't like tuples) diff --git a/xarray/core/iterators.py b/xarray/core/iterators.py index 1ba3c6f1675..eeaeb35aa9c 100644 --- a/xarray/core/iterators.py +++ b/xarray/core/iterators.py @@ -28,15 +28,9 @@ class LevelOrderIter(Iterator): -------- >>> from xarray.core.datatree import DataTree >>> from xarray.core.iterators import LevelOrderIter - >>> f = DataTree(name="f") - >>> b = DataTree(name="b", parent=f) - >>> a = DataTree(name="a", parent=b) - >>> d = DataTree(name="d", parent=b) - >>> c = DataTree(name="c", parent=d) - >>> e = DataTree(name="e", parent=d) - >>> g = DataTree(name="g", parent=f) - >>> i = DataTree(name="i", parent=g) - >>> h = DataTree(name="h", parent=i) + >>> f = DataTree.from_dict( + ... {"/b/a": None, "/b/d/c": None, "/b/d/e": None, "/g/h/i": None}, name="f" + ... 
) >>> print(f) Group: / @@ -46,19 +40,19 @@ class LevelOrderIter(Iterator): │ ├── Group: /b/d/c │ └── Group: /b/d/e └── Group: /g - └── Group: /g/i - └── Group: /g/i/h + └── Group: /g/h + └── Group: /g/h/i >>> [node.name for node in LevelOrderIter(f)] - ['f', 'b', 'g', 'a', 'd', 'i', 'c', 'e', 'h'] + ['f', 'b', 'g', 'a', 'd', 'h', 'c', 'e', 'i'] >>> [node.name for node in LevelOrderIter(f, maxlevel=3)] - ['f', 'b', 'g', 'a', 'd', 'i'] + ['f', 'b', 'g', 'a', 'd', 'h'] >>> [ ... node.name ... for node in LevelOrderIter(f, filter_=lambda n: n.name not in ("e", "g")) ... ] - ['f', 'b', 'a', 'd', 'i', 'c', 'h'] + ['f', 'b', 'a', 'd', 'h', 'c', 'i'] >>> [node.name for node in LevelOrderIter(f, stop=lambda n: n.name == "d")] - ['f', 'b', 'g', 'a', 'i', 'h'] + ['f', 'b', 'g', 'a', 'h', 'i'] """ def __init__( diff --git a/xarray/core/merge.py b/xarray/core/merge.py index d1eaa43ad89..bd927a188df 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -66,7 +66,7 @@ def broadcast_dimension_size(variables: list[Variable]) -> dict[Hashable, int]: """ dims: dict[Hashable, int] = {} for var in variables: - for dim, size in zip(var.dims, var.shape): + for dim, size in zip(var.dims, var.shape, strict=True): if dim in dims and size != dims[dim]: raise ValueError(f"index {dim!r} not aligned") dims[dim] = size @@ -267,7 +267,7 @@ def merge_collected( index, other_index, variable, other_var, index_cmp_cache ): raise MergeError( - f"conflicting values/indexes on objects to be combined fo coordinate {name!r}\n" + f"conflicting values/indexes on objects to be combined for coordinate {name!r}\n" f"first index: {index!r}\nsecond index: {other_index!r}\n" f"first variable: {variable!r}\nsecond variable: {other_var!r}\n" ) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 187a93d322f..55e754010da 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -624,7 +624,7 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs): # target dimensions dims = list(indexes_coords) - x, new_x = zip(*[indexes_coords[d] for d in dims]) + x, new_x = zip(*[indexes_coords[d] for d in dims], strict=True) destination = broadcast_variables(*new_x) # transpose to make the interpolated axis to the last position @@ -710,7 +710,9 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): _, rechunked = chunkmanager.unify_chunks(*args) - args = tuple(elem for pair in zip(rechunked, args[1::2]) for elem in pair) + args = tuple( + elem for pair in zip(rechunked, args[1::2], strict=True) for elem in pair + ) new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] @@ -798,11 +800,13 @@ def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=T # _localize expect var to be a Variable var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var) - indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)} + indexes_coords = { + _x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x, strict=True) + } # simple speed up for the local interpolation var, indexes_coords = _localize(var, indexes_coords) - x, new_x = zip(*[indexes_coords[d] for d in indexes_coords]) + x, new_x = zip(*[indexes_coords[d] for d in indexes_coords], strict=True) # put var back as a ndarray var = var.data diff --git a/xarray/core/options.py b/xarray/core/options.py index f31413a2a1a..a00aa363014 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Literal, TypedDict 
+from typing import TYPE_CHECKING, Any, Literal, TypedDict from xarray.core.utils import FrozenDict @@ -92,7 +92,7 @@ class T_Options(TypedDict): _DISPLAY_OPTIONS = frozenset(["text", "html"]) -def _positive_integer(value: int) -> bool: +def _positive_integer(value: Any) -> bool: return isinstance(value, int) and value > 0 diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 9c68ee3a1c5..12be026e539 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -29,7 +29,7 @@ class ExpectedDict(TypedDict): def unzip(iterable): - return zip(*iterable) + return zip(*iterable, strict=True) def assert_chunks_compatible(a: Dataset, b: Dataset): @@ -345,7 +345,7 @@ def _wrapper( converted_args = [ dataset_to_dataarray(arg) if is_array else arg - for is_array, arg in zip(arg_is_array, args) + for is_array, arg in zip(arg_is_array, args, strict=True) ] result = func(*converted_args, **kwargs) @@ -440,7 +440,10 @@ def _wrapper( merged_coordinates = merge([arg.coords for arg in aligned]).coords _, npargs = unzip( - sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0]) + sorted( + list(zip(xarray_indices, xarray_objs, strict=True)) + others, + key=lambda x: x[0], + ) ) # check that chunk sizes are compatible @@ -534,7 +537,7 @@ def _wrapper( # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): # mapping from dimension name to chunk index - chunk_index = dict(zip(ichunk.keys(), chunk_tuple)) + chunk_index = dict(zip(ichunk.keys(), chunk_tuple, strict=True)) blocked_args = [ ( @@ -544,7 +547,7 @@ def _wrapper( if isxr else arg ) - for isxr, arg in zip(is_xarray, npargs) + for isxr, arg in zip(is_xarray, npargs, strict=True) ] # raise nice error messages in _wrapper diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py index 2149a62dfb5..c084640e763 100644 --- a/xarray/core/resample_cftime.py +++ b/xarray/core/resample_cftime.py @@ -58,7 +58,7 @@ from xarray.core.types import SideOptions if typing.TYPE_CHECKING: - from xarray.core.types import CFTimeDatetime + from xarray.core.types import CFTimeDatetime, ResampleCompatible class CFTimeGrouper: @@ -75,7 +75,7 @@ class CFTimeGrouper: def __init__( self, - freq: str | BaseCFTimeOffset, + freq: ResampleCompatible | BaseCFTimeOffset, closed: SideOptions | None = None, label: SideOptions | None = None, origin: str | CFTimeDatetime = "start_day", diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index f7dd1210919..072012e5f51 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -133,7 +133,7 @@ def __repr__(self) -> str: attrs = [ "{k}->{v}{c}".format(k=k, v=w, c="(center)" if c else "") - for k, w, c in zip(self.dim, self.window, self.center) + for k, w, c in zip(self.dim, self.window, self.center, strict=True) ] return "{klass} [{attrs}]".format( klass=self.__class__.__name__, attrs=",".join(attrs) @@ -303,7 +303,7 @@ def __iter__(self) -> Iterator[tuple[DataArray, DataArray]]: starts = stops - window0 starts[: window0 - offset] = 0 - for label, start, stop in zip(self.window_labels, starts, stops): + for label, start, stop in zip(self.window_labels, starts, stops, strict=True): window = self.obj.isel({dim0: slice(start, stop)}) counts = window.count(dim=[dim0]) @@ -424,7 +424,9 @@ def _construct( attrs=attrs, name=obj.name, ) - return result.isel({d: slice(None, None, s) for d, s in zip(self.dim, strides)}) + return result.isel( + {d: slice(None, None, s) for d, s in zip(self.dim, strides, strict=True)} + ) def 
reduce( self, func: Callable, keep_attrs: bool | None = None, **kwargs: Any @@ -520,7 +522,7 @@ def _counts(self, keep_attrs: bool | None) -> DataArray: counts = ( self.obj.notnull(keep_attrs=keep_attrs) .rolling( - {d: w for d, w in zip(self.dim, self.window)}, + {d: w for d, w in zip(self.dim, self.window, strict=True)}, center={d: self.center[i] for i, d in enumerate(self.dim)}, ) .construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs) @@ -887,7 +889,7 @@ def construct( # Need to stride coords as well. TODO: is there a better way? coords = self.obj.isel( - {d: slice(None, None, s) for d, s in zip(self.dim, strides)} + {d: slice(None, None, s) for d, s in zip(self.dim, strides, strict=True)} ).coords attrs = self.obj.attrs if keep_attrs else {} diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 77e7ed23a51..84ce392ad32 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -86,6 +86,12 @@ def parent(self) -> Tree | None: """Parent of this node.""" return self._parent + @parent.setter + def parent(self: Tree, new_parent: Tree) -> None: + raise AttributeError( + "Cannot set parent attribute directly, you must modify the children of the other node instead using dict-like syntax" + ) + def _set_parent( self, new_parent: Tree | None, child_name: str | None = None ) -> None: @@ -553,14 +559,14 @@ def _set_item( else: current_node._set(name, item) - def __delitem__(self: Tree, key: str): + def __delitem__(self: Tree, key: str) -> None: """Remove a child node from this tree object.""" if key in self.children: child = self._children[key] del self._children[key] child.orphan() else: - raise KeyError("Cannot delete") + raise KeyError(key) def same_tree(self, other: Tree) -> bool: """True if other node is in the same tree as this node.""" diff --git a/xarray/core/types.py b/xarray/core/types.py index 3eb97f86c4a..34b6029ee15 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -3,6 +3,7 @@ import datetime import sys from collections.abc import Callable, Collection, Hashable, Iterator, Mapping, Sequence +from types import EllipsisType from typing import ( TYPE_CHECKING, Any, @@ -42,13 +43,22 @@ from xarray.core.dataset import Dataset from xarray.core.indexes import Index, Indexes from xarray.core.utils import Frozen - from xarray.core.variable import Variable - from xarray.groupers import TimeResampler + from xarray.core.variable import IndexVariable, Variable + from xarray.groupers import Grouper, TimeResampler + + GroupInput: TypeAlias = ( + str + | DataArray + | IndexVariable + | Sequence[Hashable] + | Mapping[Any, Grouper] + | None + ) try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray + DaskArray = np.ndarray # type: ignore[misc, assignment, unused-ignore] try: from cubed import Array as CubedArray @@ -188,7 +198,7 @@ def copy( # Don't change to Hashable | Collection[Hashable] # Read: https://github.com/pydata/xarray/issues/6142 -Dims = Union[str, Collection[Hashable], "ellipsis", None] +Dims = Union[str, Collection[Hashable], EllipsisType, None] # FYI in some cases we don't allow `None`, which this doesn't take account of. # FYI the `str` is for a size string, e.g. "16MB", supported by dask. 
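The new read-only `parent` property above redirects users to the dict-like child syntax rather than letting them rebind `parent` directly. A minimal sketch of the pattern the error message points at, assuming the `DataTree` node API on this branch (the node names are purely illustrative):

from xarray.core.datatree import DataTree

root = DataTree(name="root")
leaf = DataTree(name="leaf")

# leaf.parent = root  # would now raise the AttributeError added above
root["leaf"] = leaf  # attach the node through the parent's dict-like children instead
assert root["leaf"].parent is root

# __delitem__ now raises KeyError with the missing key itself for absent children,
# instead of the old, uninformative KeyError("Cannot delete").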
@@ -308,3 +318,5 @@ def copy( Bins = Union[ int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index ] + +ResampleCompatible: TypeAlias = str | datetime.timedelta | pd.Timedelta | pd.DateOffset diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 008e1d101c9..68d17fc3614 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -63,6 +63,7 @@ ) from enum import Enum from pathlib import Path +from types import EllipsisType from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, TypeVar, overload import numpy as np @@ -185,7 +186,7 @@ def equivalent(first: T, second: T) -> bool: def list_equiv(first: Sequence[T], second: Sequence[T]) -> bool: if len(first) != len(second): return False - for f, s in zip(first, second): + for f, s in zip(first, second, strict=True): if not equivalent(f, s): return False return True @@ -839,7 +840,7 @@ def parse_dims( *, check_exists: bool = True, replace_none: Literal[False], -) -> tuple[Hashable, ...] | None | ellipsis: ... +) -> tuple[Hashable, ...] | None | EllipsisType: ... def parse_dims( @@ -848,7 +849,7 @@ def parse_dims( *, check_exists: bool = True, replace_none: bool = True, -) -> tuple[Hashable, ...] | None | ellipsis: +) -> tuple[Hashable, ...] | None | EllipsisType: """Parse one or more dimensions. A single dimension must be always a str, multiple dimensions @@ -900,7 +901,7 @@ def parse_ordered_dims( *, check_exists: bool = True, replace_none: Literal[False], -) -> tuple[Hashable, ...] | None | ellipsis: ... +) -> tuple[Hashable, ...] | None | EllipsisType: ... def parse_ordered_dims( @@ -909,7 +910,7 @@ def parse_ordered_dims( *, check_exists: bool = True, replace_none: bool = True, -) -> tuple[Hashable, ...] | None | ellipsis: +) -> tuple[Hashable, ...] | None | EllipsisType: """Parse one or more dimensions. A single dimension must be always a str, multiple dimensions @@ -936,7 +937,7 @@ def parse_ordered_dims( Input dimensions as a tuple. """ if dim is not None and dim is not ... and not isinstance(dim, str) and ... in dim: - dims_set: set[Hashable | ellipsis] = set(dim) + dims_set: set[Hashable | EllipsisType] = set(dim) all_dims_set = set(all_dims) if check_exists: _check_dims(dims_set, all_dims_set) @@ -991,7 +992,7 @@ def __get__(self, obj: None | object, cls) -> type[_Accessor] | _Accessor: if obj is None: return self._accessor - return self._accessor(obj) # type: ignore # assume it is a valid accessor! + return self._accessor(obj) # type: ignore[call-arg] # assume it is a valid accessor! # Singleton type, as per https://github.com/python/typing/pull/240 diff --git a/xarray/core/variable.py b/xarray/core/variable.py index a74fb4d8ce9..d84a03c3677 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -8,6 +8,7 @@ from collections.abc import Callable, Hashable, Mapping, Sequence from datetime import timedelta from functools import partial +from types import EllipsisType from typing import TYPE_CHECKING, Any, NoReturn, cast import numpy as np @@ -318,7 +319,7 @@ def as_compatible_data( # validate whether the data is valid data types. data = np.asarray(data) - if isinstance(data, np.ndarray) and data.dtype.kind in "OMm": + if data.dtype.kind in "OMm": data = _possibly_convert_objects(data) return _maybe_wrap_data(data) @@ -646,7 +647,7 @@ def _broadcast_indexes(self, key): # If all key is 1-dimensional and there are no duplicate labels, # key can be mapped as an OuterIndexer. 
dims = [] - for k, d in zip(key, self.dims): + for k, d in zip(key, self.dims, strict=True): if isinstance(k, Variable): if len(k.dims) > 1: return self._broadcast_indexes_vectorized(key) @@ -660,13 +661,15 @@ def _broadcast_indexes(self, key): def _broadcast_indexes_basic(self, key): dims = tuple( - dim for k, dim in zip(key, self.dims) if not isinstance(k, integer_types) + dim + for k, dim in zip(key, self.dims, strict=True) + if not isinstance(k, integer_types) ) return dims, BasicIndexer(key), None def _validate_indexers(self, key): """Make sanity checks""" - for dim, k in zip(self.dims, key): + for dim, k in zip(self.dims, key, strict=True): if not isinstance(k, BASIC_INDEXING_TYPES): if not isinstance(k, Variable): if not is_duck_array(k): @@ -705,7 +708,7 @@ def _broadcast_indexes_outer(self, key): # drop dim if k is integer or if k is a 0d dask array dims = tuple( k.dims[0] if isinstance(k, Variable) else dim - for k, dim in zip(key, self.dims) + for k, dim in zip(key, self.dims, strict=True) if (not isinstance(k, integer_types) and not is_0d_dask_array(k)) ) @@ -728,7 +731,7 @@ def _broadcast_indexes_outer(self, key): def _broadcast_indexes_vectorized(self, key): variables = [] out_dims_set = OrderedSet() - for dim, value in zip(self.dims, key): + for dim, value in zip(self.dims, key, strict=True): if isinstance(value, slice): out_dims_set.add(dim) else: @@ -750,7 +753,7 @@ def _broadcast_indexes_vectorized(self, key): variable_dims.update(variable.dims) slices = [] - for i, (dim, value) in enumerate(zip(self.dims, key)): + for i, (dim, value) in enumerate(zip(self.dims, key, strict=True)): if isinstance(value, slice): if dim in variable_dims: # We only convert slice objects to variables if they share @@ -1133,7 +1136,7 @@ def _pad_options_dim_to_index( if fill_with_shape: return [ (n, n) if d not in pad_option else pad_option[d] - for d, n in zip(self.dims, self.data.shape) + for d, n in zip(self.dims, self.data.shape, strict=True) ] return [(0, 0) if d not in pad_option else pad_option[d] for d in self.dims] @@ -1287,7 +1290,7 @@ def roll(self, shifts=None, **shifts_kwargs): @deprecate_dims def transpose( self, - *dim: Hashable | ellipsis, + *dim: Hashable | EllipsisType, missing_dims: ErrorOptionsWithWarn = "raise", ) -> Self: """Return a new Variable object with transposed dimensions. @@ -1376,7 +1379,7 @@ def set_dims(self, dim, shape=None): # writeable if possible expanded_data = self.data elif shape is not None: - dims_map = dict(zip(dim, shape)) + dims_map = dict(zip(dim, shape, strict=True)) tmp_shape = tuple(dims_map[d] for d in expanded_dims) expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape) else: @@ -1490,7 +1493,7 @@ def _unstack_once( dim: Hashable, fill_value=dtypes.NA, sparse: bool = False, - ) -> Self: + ) -> Variable: """ Unstacks this variable given an index to unstack and the name of the dimension to which the index refers. 
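As background for these `_unstack_once` hunks, a small sketch of what unstacking means at the `Variable` level: a pure reshape of one dimension into several, driven only by the sizes you pass, with no index involved. The dimension names and sizes below are illustrative:

import xarray as xr

v = xr.Variable(("z",), list(range(6)))
w = v.unstack(z={"x": 2, "y": 3})  # split "z" into two new trailing dimensions
assert w.dims == ("x", "y") and w.shape == (2, 3)

The index-based `_unstack_once` path changed here additionally downgrades the result to a base `Variable` (see the `to_base_variable()` call below), which is what the new `test_unstack_index_var` test in this PR exercises.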
@@ -1526,13 +1529,13 @@ def _unstack_once( # unstacking a dense multitindexed array to a sparse array from sparse import COO - codes = zip(*index.codes) + codes = zip(*index.codes, strict=True) if reordered.ndim == 1: indexes = codes else: sizes = itertools.product(*[range(s) for s in reordered.shape[:-1]]) tuple_indexes = itertools.product(sizes, codes) - indexes = map(lambda x: list(itertools.chain(*x)), tuple_indexes) # type: ignore + indexes = map(lambda x: list(itertools.chain(*x)), tuple_indexes) # type: ignore[assignment] data = COO( coords=np.array(list(indexes)).T, @@ -1550,10 +1553,10 @@ def _unstack_once( # case the destinations will be NaN / zero. data[(..., *indexer)] = reordered - return self._replace(dims=new_dims, data=data) + return self.to_base_variable()._replace(dims=new_dims, data=data) @partial(deprecate_dims, old_name="dimensions") - def unstack(self, dim=None, **dim_kwargs): + def unstack(self, dim=None, **dim_kwargs) -> Variable: """ Unstack an existing dimension into multiple new dimensions. @@ -1657,7 +1660,7 @@ def reduce( # type: ignore[override] _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs ) - # Noe that the call order for Variable.mean is + # Note that the call order for Variable.mean is # Variable.mean -> NamedArray.mean -> Variable.reduce # -> NamedArray.reduce result = super().reduce( @@ -2060,7 +2063,9 @@ def rolling_window( if utils.is_scalar(dim): for name, arg in zip( - ["window", "window_dim", "center"], [window, window_dim, center] + ["window", "window_dim", "center"], + [window, window_dim, center], + strict=True, ): if not utils.is_scalar(arg): raise ValueError( @@ -2088,7 +2093,7 @@ def rolling_window( ) pads = {} - for d, win, cent in zip(dim, window, center): + for d, win, cent in zip(dim, window, center, strict=True): if cent: start = win // 2 # 10 -> 5, 9 -> 4 end = win - 1 - start @@ -2398,7 +2403,7 @@ def _unravel_argminmax( result = { d: Variable(dims=result_dims, data=i) - for d, i in zip(dim, result_unravelled_indices) + for d, i in zip(dim, result_unravelled_indices, strict=True) } if keep_attrs is None: @@ -2869,7 +2874,7 @@ def _unified_dims(variables): var_dims = var.dims _raise_if_any_duplicate_dimensions(var_dims, err_context="Broadcasting") - for d, s in zip(var_dims, var.shape): + for d, s in zip(var_dims, var.shape, strict=True): if d not in all_dims: all_dims[d] = s elif all_dims[d] != s: @@ -2997,7 +3002,7 @@ def calculate_dimensions(variables: Mapping[Any, Variable]) -> dict[Hashable, in last_used = {} scalar_vars = {k for k, v in variables.items() if not v.dims} for k, var in variables.items(): - for dim, size in zip(var.dims, var.shape): + for dim, size in zip(var.dims, var.shape, strict=True): if dim in scalar_vars: raise ValueError( f"dimension {dim!r} already exists as a scalar variable" diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 8cb90ac1b2b..2c6e7d4282a 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -182,7 +182,7 @@ def _weight_check(w): if is_duck_dask_array(weights.data): # assign to copy - else the check is not triggered weights = weights.copy( - data=weights.data.map_blocks(_weight_check, dtype=weights.dtype), + data=weights.data.map_blocks(_weight_check, dtype=weights.dtype), # type: ignore[call-arg, arg-type] deep=False, ) @@ -264,7 +264,9 @@ def _sum_of_squares( demeaned = da - da.weighted(self.weights).mean(dim=dim) - return self._reduce((demeaned**2), self.weights, dim=dim, skipna=skipna) + # TODO: unsure why mypy complains about these 
being DataArray return types + # rather than T_DataArray? + return self._reduce((demeaned**2), self.weights, dim=dim, skipna=skipna) # type: ignore[return-value] def _weighted_sum( self, @@ -274,7 +276,7 @@ def _weighted_sum( ) -> T_DataArray: """Reduce a DataArray by a weighted ``sum`` along some dimension(s).""" - return self._reduce(da, self.weights, dim=dim, skipna=skipna) + return self._reduce(da, self.weights, dim=dim, skipna=skipna) # type: ignore[return-value] def _weighted_mean( self, diff --git a/xarray/datatree_/docs/source/conf.py b/xarray/datatree_/docs/source/conf.py index 430dbb5bf6d..c32f2b126ed 100644 --- a/xarray/datatree_/docs/source/conf.py +++ b/xarray/datatree_/docs/source/conf.py @@ -17,9 +17,9 @@ import os import sys -import sphinx_autosummary_accessors # type: ignore +import sphinx_autosummary_accessors # type: ignore[import-not-found] -import datatree # type: ignore +import datatree # type: ignore[import-not-found] # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the diff --git a/xarray/datatree_/docs/source/data-structures.rst b/xarray/datatree_/docs/source/data-structures.rst index 02e4a31f688..90b786701cc 100644 --- a/xarray/datatree_/docs/source/data-structures.rst +++ b/xarray/datatree_/docs/source/data-structures.rst @@ -40,7 +40,7 @@ stored under hashable keys), and so has the same key properties: - ``dims``: a dictionary mapping of dimension names to lengths, for the variables in this node, - ``data_vars``: a dict-like container of DataArrays corresponding to variables in this node, - ``coords``: another dict-like container of DataArrays, corresponding to coordinate variables in this node, -- ``attrs``: dict to hold arbitary metadata relevant to data in this node. +- ``attrs``: dict to hold arbitrary metadata relevant to data in this node. A single ``DataTree`` object acts much like a single ``Dataset`` object, and has a similar set of dict-like methods defined upon it. However, ``DataTree``'s can also contain other ``DataTree`` objects, so they can be thought of as nested dict-like diff --git a/xarray/datatree_/docs/source/hierarchical-data.rst b/xarray/datatree_/docs/source/hierarchical-data.rst index d4f58847718..ceb3fc46b44 100644 --- a/xarray/datatree_/docs/source/hierarchical-data.rst +++ b/xarray/datatree_/docs/source/hierarchical-data.rst @@ -133,7 +133,7 @@ We can add Herbert to the family tree without displacing Homer by :py:meth:`~Dat .. note:: This example shows a minor subtlety - the returned tree has Homer's brother listed as ``"Herbert"``, - but the original node was named "Herbert". Not only are names overriden when stored as keys like this, + but the original node was named "Herbert". Not only are names overridden when stored as keys like this, but the new node is a copy, so that the original node that was reference is unchanged (i.e. ``herbert.name == "Herb"`` still). In other words, nodes are copied into trees, not inserted into them. This is intentional, and mirrors the behaviour when storing named ``xarray.DataArray`` objects inside datasets. 
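A short sketch of the copy-on-insertion behaviour the note above describes, assuming the `DataTree.assign` API the docs reference (names are illustrative):

from xarray.core.datatree import DataTree

homer = DataTree(name="Homer")
herbert = DataTree(name="Herb")

# store the node under the key "Herbert"; the node is copied, not moved
homer_with_brother = homer.assign({"Herbert": herbert})

print(homer_with_brother["Herbert"].name)  # "Herbert": the storage key overrides the name
print(herbert.name)  # "Herb": the original node is left untouched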
diff --git a/xarray/groupers.py b/xarray/groupers.py index f70cad655e8..e4cb884e6de 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -14,14 +14,20 @@ import numpy as np import pandas as pd -from xarray.coding.cftime_offsets import _new_to_legacy_freq +from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.core import duck_array_ops from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper -from xarray.core.types import Bins, DatetimeLike, GroupIndices, SideOptions +from xarray.core.types import ( + Bins, + DatetimeLike, + GroupIndices, + ResampleCompatible, + SideOptions, +) from xarray.core.variable import Variable __all__ = [ @@ -184,7 +190,7 @@ def _factorize_unique(self) -> EncodedGroups: raise ValueError( "Failed to group data. Are you grouping by a variable that is all NaN?" ) - codes = self.group.copy(data=codes_.reshape(self.group.shape)) + codes = self.group.copy(data=codes_.reshape(self.group.shape), deep=False) unique_coord = Variable( dims=codes.name, data=unique_values, attrs=self.group.attrs ) @@ -212,7 +218,7 @@ def _factorize_dummy(self) -> EncodedGroups: full_index = pd.RangeIndex(self.group.size) coords = Coordinates() else: - codes = self.group.copy(data=size_range) + codes = self.group.copy(data=size_range, deep=False) unique_coord = self.group.variable.to_base_variable() full_index = self.group_as_index if isinstance(full_index, pd.MultiIndex): @@ -336,7 +342,7 @@ class TimeResampler(Resampler): Attributes ---------- - freq : str + freq : str, datetime.timedelta, pandas.Timedelta, or pandas.DateOffset Frequency to resample to. See `Pandas frequency aliases `_ for a list of possible values. @@ -358,7 +364,7 @@ class TimeResampler(Resampler): An offset timedelta added to the origin.
""" - freq: str + freq: ResampleCompatible closed: SideOptions | None = field(default=None) label: SideOptions | None = field(default=None) origin: str | DatetimeLike = field(default="start_day") @@ -388,6 +394,12 @@ def _init_properties(self, group: T_Group) -> None: offset=offset, ) else: + if isinstance(self.freq, BaseCFTimeOffset): + raise ValueError( + "'BaseCFTimeOffset' resample frequencies are only supported " + "when resampling a 'CFTimeIndex'" + ) + self.index_grouper = pd.Grouper( # TODO remove once requiring pandas >= 2.2 freq=_new_to_legacy_freq(self.freq), @@ -431,14 +443,14 @@ def factorize(self, group: T_Group) -> EncodedGroups: full_index, first_items, codes_ = self._get_index_and_items() sbins = first_items.values.astype(np.int64) group_indices: GroupIndices = tuple( - [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True)] + [slice(sbins[-1], None)] ) unique_coord = Variable( dims=group.name, data=first_items.index, attrs=group.attrs ) - codes = group.copy(data=codes_.reshape(group.shape)) + codes = group.copy(data=codes_.reshape(group.shape), deep=False) return EncodedGroups( codes=codes, diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 57b17385558..a7d7ed7994f 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -3,7 +3,7 @@ import sys from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence from enum import Enum -from types import ModuleType +from types import EllipsisType, ModuleType from typing import ( TYPE_CHECKING, Any, @@ -91,7 +91,7 @@ def dtype(self) -> _DType_co: ... # TODO: np.array_api was bugged and didn't allow (None,), but should! # https://github.com/numpy/numpy/pull/25022 # https://github.com/data-apis/array-api/pull/674 -_IndexKey = Union[int, slice, "ellipsis"] +_IndexKey = Union[int, slice, EllipsisType] _IndexKeys = tuple[_IndexKey, ...] # tuple[Union[_IndexKey, None], ...] _IndexKeyLike = Union[_IndexKey, _IndexKeys] diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index c5841f6913e..0d1a50a8d3c 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -5,6 +5,7 @@ import sys import warnings from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence +from types import EllipsisType from typing import ( TYPE_CHECKING, Any, @@ -736,14 +737,14 @@ def chunksizes( """ data = self._data if isinstance(data, _chunkedarray): - return dict(zip(self.dims, data.chunks)) + return dict(zip(self.dims, data.chunks, strict=True)) else: return {} @property def sizes(self) -> dict[_Dim, _IntOrUnknown]: """Ordered mapping from dimension names to lengths.""" - return dict(zip(self.dims, self.shape)) + return dict(zip(self.dims, self.shape, strict=True)) def chunk( self, @@ -947,7 +948,7 @@ def _nonzero(self: T_NamedArrayInteger) -> tuple[T_NamedArrayInteger, ...]: _attrs = self.attrs return tuple( cast("T_NamedArrayInteger", self._new((dim,), nz, _attrs)) - for nz, dim in zip(nonzeros, self.dims) + for nz, dim in zip(nonzeros, self.dims, strict=True) ) def __repr__(self) -> str: @@ -996,7 +997,7 @@ def _to_dense(self) -> NamedArray[Any, _DType_co]: def permute_dims( self, - *dim: Iterable[_Dim] | ellipsis, + *dim: Iterable[_Dim] | EllipsisType, missing_dims: ErrorOptionsWithWarn = "raise", ) -> NamedArray[Any, _DType_co]: """Return a new object with transposed dimensions. 
@@ -1037,8 +1038,8 @@ def permute_dims( # or dims are in same order return self.copy(deep=False) - axes_result = self.get_axis_num(dims) - axes = (axes_result,) if isinstance(axes_result, int) else axes_result + axes = self.get_axis_num(dims) + assert isinstance(axes, tuple) return permute_dims(self, axes) diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py index e3a4f6ba1ad..606e72acd0e 100644 --- a/xarray/namedarray/utils.py +++ b/xarray/namedarray/utils.py @@ -20,8 +20,8 @@ from dask.array.core import Array as DaskArray from dask.typing import DaskCollection except ImportError: - DaskArray = NDArray # type: ignore - DaskCollection: Any = NDArray # type: ignore + DaskArray = NDArray # type: ignore[assignment, misc] + DaskCollection: Any = NDArray # type: ignore[no-redef] from xarray.namedarray._typing import _Dim, duckarray diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index ae10c3e9920..b759f0bb944 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -210,7 +210,7 @@ def _prepare_plot1d_data( plts.update( {k: darray.coords[v] for k, v in coords_to_plot.items() if v is not None} ) - plts = dict(zip(plts.keys(), broadcast(*(plts.values())))) + plts = dict(zip(plts.keys(), broadcast(*(plts.values())), strict=True)) return plts @@ -1089,7 +1089,9 @@ def _add_labels( """Set x, y, z labels.""" add_labels = [add_labels] * 3 if isinstance(add_labels, bool) else add_labels axes: tuple[Literal["x", "y", "z"], ...] = ("x", "y", "z") - for axis, add_label, darray, suffix in zip(axes, add_labels, darrays, suffixes): + for axis, add_label, darray, suffix in zip( + axes, add_labels, darrays, suffixes, strict=True + ): if darray is None: continue diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 8ad3b4a6296..04d298efcce 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -697,7 +697,7 @@ def wrapper(dataset_plotfunc: F) -> F: dataset_plotfunc.__doc__ = ds_doc return dataset_plotfunc - return wrapper + return wrapper # type: ignore[return-value] def _normalize_args( @@ -737,7 +737,7 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr coords[key] = darray dims.update(darray.dims) - # Trim dataset from unneccessary dims: + # Trim dataset from unnecessary dims: ds_trimmed = ds.drop_dims(ds.sizes.keys() - dims) # TODO: Use ds.dims in the future # The dataarray has to include all the dims. 
Broadcast to that shape diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 4c0d9b96a03..4e43ad2826c 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -362,7 +362,7 @@ def map_dataarray( rgb=kwargs.get("rgb", None), ) - for d, ax in zip(self.name_dicts.flat, self.axs.flat): + for d, ax in zip(self.name_dicts.flat, self.axs.flat, strict=True): # None is the sentinel value if d is not None: subset = self.data.loc[d] @@ -505,7 +505,10 @@ def map_plot1d( # Plot the data for each subplot: for add_lbls, d, ax in zip( - add_labels_.reshape((self.axs.size, -1)), name_dicts.flat, self.axs.flat + add_labels_.reshape((self.axs.size, -1)), + name_dicts.flat, + self.axs.flat, + strict=True, ): func_kwargs["add_labels"] = add_lbls # None is the sentinel value @@ -571,7 +574,7 @@ def map_dataarray_line( ) -> T_FacetGrid: from xarray.plot.dataarray_plot import _infer_line_data - for d, ax in zip(self.name_dicts.flat, self.axs.flat): + for d, ax in zip(self.name_dicts.flat, self.axs.flat, strict=True): # None is the sentinel value if d is not None: subset = self.data.loc[d] @@ -638,7 +641,7 @@ def map_dataset( raise ValueError("Please provide scale.") # TODO: come up with an algorithm for reasonable scale choice - for d, ax in zip(self.name_dicts.flat, self.axs.flat): + for d, ax in zip(self.name_dicts.flat, self.axs.flat, strict=True): # None is the sentinel value if d is not None: subset = self.data.loc[d] @@ -672,7 +675,7 @@ def _finalize_grid(self, *axlabels: Hashable) -> None: self.set_titles() self.fig.tight_layout() - for ax, namedict in zip(self.axs.flat, self.name_dicts.flat): + for ax, namedict in zip(self.axs.flat, self.name_dicts.flat, strict=True): if namedict is None: ax.set_visible(False) @@ -824,7 +827,7 @@ def _set_lims( # Set limits: for ax in self.axs.flat: for (axis, data_limit), parameter_limit in zip( - lims_largest.items(), (x, y, z) + lims_largest.items(), (x, y, z), strict=True ): set_lim = getattr(ax, f"set_{axis}lim", None) if set_lim: @@ -834,7 +837,7 @@ def set_axis_labels(self, *axlabels: Hashable) -> None: """Set axis labels on the left column and bottom row of the grid.""" from xarray.core.dataarray import DataArray - for var, axis in zip(axlabels, ["x", "y", "z"]): + for var, axis in zip(axlabels, ["x", "y", "z"], strict=False): if var is not None: if isinstance(var, DataArray): getattr(self, f"set_{axis}labels")(label_from_attrs(var)) @@ -893,7 +896,7 @@ def set_titles( nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template) if self._single_group: - for d, ax in zip(self.name_dicts.flat, self.axs.flat): + for d, ax in zip(self.name_dicts.flat, self.axs.flat, strict=True): # Only label the ones with data if d is not None: coord, value = list(d.items()).pop() @@ -902,7 +905,7 @@ def set_titles( else: # The row titles on the right edge of the grid for index, (ax, row_name, handle) in enumerate( - zip(self.axs[:, -1], self.row_names, self.row_labels) + zip(self.axs[:, -1], self.row_names, self.row_labels, strict=True) ): title = nicetitle(coord=self._row_var, value=row_name, maxchar=maxchar) if not handle: @@ -921,7 +924,7 @@ def set_titles( # The column titles on the top row for index, (ax, col_name, handle) in enumerate( - zip(self.axs[0, :], self.col_names, self.col_labels) + zip(self.axs[0, :], self.col_names, self.col_labels, strict=True) ): title = nicetitle(coord=self._col_var, value=col_name, maxchar=maxchar) if not handle: @@ -992,7 +995,7 @@ def map( """ import matplotlib.pyplot as plt - for ax, namedict 
in zip(self.axs.flat, self.name_dicts.flat): + for ax, namedict in zip(self.axs.flat, self.name_dicts.flat, strict=True): if namedict is not None: data = self.data.loc[namedict] plt.sca(ax) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a037123a46f..22d447316ca 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -45,7 +45,7 @@ try: import matplotlib.pyplot as plt except ImportError: - plt: Any = None # type: ignore + plt: Any = None # type: ignore[no-redef] ROBUST_PERCENTILE = 2.0 @@ -577,8 +577,12 @@ def _interval_to_double_bound_points( xarray1 = np.array([x.left for x in xarray]) xarray2 = np.array([x.right for x in xarray]) - xarray_out = np.array(list(itertools.chain.from_iterable(zip(xarray1, xarray2)))) - yarray_out = np.array(list(itertools.chain.from_iterable(zip(yarray, yarray)))) + xarray_out = np.array( + list(itertools.chain.from_iterable(zip(xarray1, xarray2, strict=True))) + ) + yarray_out = np.array( + list(itertools.chain.from_iterable(zip(yarray, yarray, strict=True))) + ) return xarray_out, yarray_out @@ -1148,7 +1152,7 @@ def _get_color_and_size(value): kw = dict(markeredgewidth=self.get_linewidths()[0], alpha=self.get_alpha()) kw.update(kwargs) - for val, lab in zip(values, label_values): + for val, lab in zip(values, label_values, strict=True): color, size = _get_color_and_size(val) if isinstance(self, mpl.collections.PathCollection): @@ -1170,7 +1174,7 @@ def _legend_add_subtitle(handles, labels, text): if text and len(handles) > 1: # Create a blank handle that's not visible, the - # invisibillity will be used to discern which are subtitles + # invisibility will be used to discern which are subtitles # or not: blank_handle = plt.Line2D([], [], label=text) blank_handle.set_visible(False) @@ -1347,7 +1351,7 @@ def _parse_size( widths = np.asarray(min_width + scl * (max_width - min_width)) if scl.mask.any(): widths[scl.mask] = 0 - sizes = dict(zip(levels, widths)) + sizes = dict(zip(levels, widths, strict=True)) return pd.Series(sizes) @@ -1606,7 +1610,7 @@ def _lookup(self) -> pd.Series: if self._values_unique is None: raise ValueError("self.data can't be None.") - return pd.Series(dict(zip(self._values_unique, self._data_unique))) + return pd.Series(dict(zip(self._values_unique, self._data_unique, strict=True))) def _lookup_arr(self, x) -> np.ndarray: # Use reindex to be less sensitive to float errors. reindex only @@ -1818,7 +1822,9 @@ def _guess_coords_to_plot( # one of related mpl kwargs has been used. This should have similar behaviour as # * plt.plot(x, y) -> Multiple lines with different colors if y is 2d. # * plt.plot(x, y, color="red") -> Multiple red lines if y is 2d. - for k, dim, ign_kws in zip(default_guess, available_coords, ignore_guess_kwargs): + for k, dim, ign_kws in zip( + default_guess, available_coords, ignore_guess_kwargs, strict=False + ): if coords_to_plot.get(k, None) is None and all( kwargs.get(ign_kw, None) is None for ign_kw in ign_kws ): diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 3dec6a25616..3a0cc96b20d 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -211,7 +211,9 @@ def assert_identical(a, b, from_root=True): if isinstance(a, Variable): assert a.identical(b), formatting.diff_array_repr(a, b, "identical") elif isinstance(a, DataArray): - assert a.name == b.name + assert ( + a.name == b.name + ), f"DataArray names are different. 
L: {a.name}, R: {b.name}" assert a.identical(b), formatting.diff_array_repr(a, b, "identical") elif isinstance(a, Dataset | Variable): assert a.identical(b), formatting.diff_dataset_repr(a, b, "identical") diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index b4d3871c229..0e43738ed99 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -134,7 +134,6 @@ def _importorskip( has_pint, requires_pint = _importorskip("pint") has_numexpr, requires_numexpr = _importorskip("numexpr") has_flox, requires_flox = _importorskip("flox") -has_pandas_ge_2_1, __ = _importorskip("pandas", "2.1") has_pandas_ge_2_2, __ = _importorskip("pandas", "2.2") has_pandas_3, requires_pandas_3 = _importorskip("pandas", "3.0.0.dev0") @@ -324,8 +323,8 @@ def create_test_data( obj["var4"] = ( "dim1", pd.Categorical( - np.random.choice( - list(string.ascii_lowercase[: np.random.randint(5)]), + rs.choice( + list(string.ascii_lowercase[: rs.randint(1, 5)]), size=dim_sizes[0], ) ), @@ -333,7 +332,7 @@ def create_test_data( if dim_sizes == _DEFAULT_TEST_DIM_SIZES: numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") else: - numbers_values = np.random.randint(0, 3, _dims["dim3"], dtype="int64") + numbers_values = rs.randint(0, 3, _dims["dim3"], dtype="int64") obj.coords["numbers"] = ("dim3", numbers_values) obj.encoding = {"foo": "bar"} assert_writeable(obj) diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py index a32b0e08bea..065e49a372a 100644 --- a/xarray/tests/conftest.py +++ b/xarray/tests/conftest.py @@ -139,6 +139,26 @@ def d(request, backend, type) -> DataArray | Dataset: raise ValueError +@pytest.fixture +def byte_attrs_dataset(): + """For testing issue #9407""" + null_byte = b"\x00" + other_bytes = bytes(range(1, 256)) + ds = Dataset({"x": 1}, coords={"x_coord": [1]}) + ds["x"].attrs["null_byte"] = null_byte + ds["x"].attrs["other_bytes"] = other_bytes + + expected = ds.copy() + expected["x"].attrs["null_byte"] = "" + expected["x"].attrs["other_bytes"] = other_bytes.decode(errors="replace") + + return { + "input": ds, + "expected": expected, + "h5netcdf_error": r"Invalid value provided for attribute .*: .*\. 
Null characters .*", + } + + @pytest.fixture(scope="module") def create_test_datatree(): """ @@ -176,14 +196,17 @@ def _create_test_datatree(modify=lambda ds: ds): set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})) root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})) - # Avoid using __init__ so we can independently test it - root: DataTree = DataTree(data=root_data) - set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) - DataTree(name="set1", parent=set1) - DataTree(name="set2", parent=set1) - set2: DataTree = DataTree(name="set2", parent=root, data=set2_data) - DataTree(name="set1", parent=set2) - DataTree(name="set3", parent=root) + root = DataTree.from_dict( + { + "/": root_data, + "/set1": set1_data, + "/set1/set1": None, + "/set1/set2": None, + "/set2": set2_data, + "/set2/set1": None, + "/set3": None, + } + ) return root diff --git a/xarray/tests/test_accessor_dt.py b/xarray/tests/test_accessor_dt.py index 686bce943fa..587f43a5d7f 100644 --- a/xarray/tests/test_accessor_dt.py +++ b/xarray/tests/test_accessor_dt.py @@ -97,6 +97,7 @@ def test_field_access(self, field) -> None: actual = getattr(self.data.time.dt, field) else: actual = getattr(self.data.time.dt, field) + assert not isinstance(actual.variable, xr.IndexVariable) assert expected.dtype == actual.dtype assert_identical(expected, actual) @@ -141,6 +142,17 @@ def test_strftime(self) -> None: "2000-01-01 01:00:00" == self.data.time.dt.strftime("%Y-%m-%d %H:%M:%S")[1] ) + @requires_cftime + @pytest.mark.parametrize( + "calendar,expected", + [("standard", 366), ("noleap", 365), ("360_day", 360), ("all_leap", 366)], + ) + def test_days_in_year(self, calendar, expected) -> None: + assert ( + self.data.convert_calendar(calendar, align_on="year").time.dt.days_in_year + == expected + ).all() + def test_not_datetime_type(self) -> None: nontime_data = self.data.copy() int_data = np.arange(len(self.data.time)).astype("int8") @@ -176,6 +188,7 @@ def test_not_datetime_type(self) -> None: "is_year_start", "is_year_end", "is_leap_year", + "days_in_year", ], ) def test_dask_field_access(self, field) -> None: @@ -697,3 +710,46 @@ def test_cftime_round_accessor( result = cftime_rounding_dataarray.dt.round(freq) assert_identical(result, expected) + + +@pytest.mark.parametrize( + "use_cftime", + [False, pytest.param(True, marks=requires_cftime)], + ids=lambda x: f"use_cftime={x}", +) +@pytest.mark.parametrize( + "use_dask", + [False, pytest.param(True, marks=requires_dask)], + ids=lambda x: f"use_dask={x}", +) +def test_decimal_year(use_cftime, use_dask) -> None: + year = 2000 + periods = 10 + freq = "h" + + shape = (2, 5) + dims = ["x", "y"] + hours_in_year = 24 * 366 + + times = xr.date_range(f"{year}", periods=periods, freq=freq, use_cftime=use_cftime) + + da = xr.DataArray(times.values.reshape(shape), dims=dims) + + if use_dask: + da = da.chunk({"y": 2}) + # Computing the decimal year for a cftime datetime array requires a + # number of small computes (6): + # - 4x one compute per .dt accessor call (requires inspecting one + # object-dtype array element to see if it is time-like) + # - 2x one compute per calendar inference (requires inspecting one + # array element to read off the calendar) + max_computes = 6 * use_cftime + with raise_if_dask_computes(max_computes=max_computes): + result = da.dt.decimal_year + else: + result = da.dt.decimal_year + + expected = xr.DataArray( + year + np.arange(periods).reshape(shape) / hours_in_year, dims=dims + ) + 
xr.testing.assert_equal(result, expected) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index 03c77e2365e..3ebb67cf1f3 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -68,7 +68,7 @@ def test_broadcast(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: expected = xr.broadcast(np_arr, np_arr2) actual = xr.broadcast(xp_arr, xp_arr2) assert len(actual) == len(expected) - for a, e in zip(actual, expected): + for a, e in zip(actual, expected, strict=True): assert isinstance(a.data, Array) assert_equal(a, e) diff --git a/xarray/tests/test_assertions.py b/xarray/tests/test_assertions.py index 20b5e163662..2f5a8739b28 100644 --- a/xarray/tests/test_assertions.py +++ b/xarray/tests/test_assertions.py @@ -11,7 +11,7 @@ try: from dask.array import from_array as dask_from_array except ImportError: - dask_from_array = lambda x: x # type: ignore + dask_from_array = lambda x: x # type: ignore[assignment, misc] try: import pint diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index c755924f583..13258fcf6ea 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -595,7 +595,7 @@ def test_roundtrip_cftime_datetime_data(self) -> None: assert actual.t.encoding["calendar"] == expected_calendar def test_roundtrip_timedelta_data(self) -> None: - time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]) + time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]) # type: ignore[arg-type, unused-ignore] expected = Dataset({"td": ("td", time_deltas), "td0": time_deltas[0]}) with self.roundtrip(expected) as actual: assert_identical(expected, actual) @@ -1223,7 +1223,9 @@ def test_invalid_dataarray_names_raise(self) -> None: ve = (ValueError, "string must be length 1 or") data = np.random.random((2, 2)) da = xr.DataArray(data) - for name, (error, msg) in zip([0, (4, 5), True, ""], [te, te, te, ve]): + for name, (error, msg) in zip( + [0, (4, 5), True, ""], [te, te, te, ve], strict=True + ): ds = Dataset({name: da}) with pytest.raises(error) as excinfo: with self.roundtrip(ds): @@ -1404,6 +1406,13 @@ def test_refresh_from_disk(self) -> None: a.close() b.close() + def test_byte_attrs(self, byte_attrs_dataset: dict[str, Any]) -> None: + # test for issue #9407 + input = byte_attrs_dataset["input"] + expected = byte_attrs_dataset["expected"] + with self.roundtrip(input) as actual: + assert_identical(actual, expected) + _counter = itertools.count() @@ -1701,7 +1710,7 @@ def test_base_chunking_uses_disk_chunk_sizes(self) -> None: open_kwargs={"chunks": {}}, ) as ds: for chunksizes, expected in zip( - ds["image"].data.chunks, (1, y_chunksize, x_chunksize) + ds["image"].data.chunks, (1, y_chunksize, x_chunksize), strict=True ): assert all(np.asanyarray(chunksizes) == expected) @@ -2204,7 +2213,7 @@ def create_store(self): store_target, mode="w", **self.version_kwargs ) - def save(self, dataset, store_target, **kwargs): + def save(self, dataset, store_target, **kwargs): # type: ignore[override] return dataset.to_zarr(store=store_target, **kwargs, **self.version_kwargs) @contextlib.contextmanager @@ -3861,6 +3870,10 @@ def test_decode_utf8_warning(self) -> None: assert ds.title == title assert "attribute 'title' of h5netcdf object '/'" in str(w[0].message) + def test_byte_attrs(self, byte_attrs_dataset: dict[str, Any]) -> None: + with pytest.raises(ValueError, match=byte_attrs_dataset["h5netcdf_error"]): + super().test_byte_attrs(byte_attrs_dataset) + @requires_h5netcdf @requires_netCDF4 @@ -5030,9 +5043,10 @@ def 
test_extract_nc4_variable_encoding_netcdf4(self): var = xr.Variable(("x",), [1, 2, 3], {}, {"compression": "szlib"}) _extract_nc4_variable_encoding(var, backend="netCDF4", raise_on_invalid=True) + @pytest.mark.xfail def test_extract_h5nc_encoding(self) -> None: # not supported with h5netcdf (yet) - var = xr.Variable(("x",), [1, 2, 3], {}, {"least_sigificant_digit": 2}) + var = xr.Variable(("x",), [1, 2, 3], {}, {"least_significant_digit": 2}) with pytest.raises(ValueError, match=r"unexpected encoding"): _extract_nc4_variable_encoding(var, raise_on_invalid=True) @@ -5932,7 +5946,7 @@ def test_zarr_region_index_write(self, tmp_path): ds.to_zarr(tmp_path / "test.zarr") region: Mapping[str, slice] | Literal["auto"] - for region in [region_slice, "auto"]: # type: ignore + for region in [region_slice, "auto"]: # type: ignore[assignment] with patch.object( ZarrStore, "set_variables", diff --git a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py index d4f8b7ed31d..3a4b1d76287 100644 --- a/xarray/tests/test_backends_api.py +++ b/xarray/tests/test_backends_api.py @@ -69,7 +69,7 @@ def open_dataset( class PassThroughBackendEntrypoint(xr.backends.BackendEntrypoint): """Access an object passed to the `open_dataset` method.""" - def open_dataset(self, dataset, *, drop_variables=None): + def open_dataset(self, dataset, *, drop_variables=None): # type: ignore[override] """Return the first argument.""" return dataset @@ -86,7 +86,7 @@ def explicit_chunks(chunks, shape): if isinstance(chunk, Number) else chunk ) - for chunk, size in zip(chunks, shape) + for chunk, size in zip(chunks, shape, strict=True) ) @@ -104,7 +104,9 @@ def create_dataset(self, shape, pref_chunks): self.var_name: xr.Variable( dims, np.empty(shape, dtype=np.dtype("V1")), - encoding={"preferred_chunks": dict(zip(dims, pref_chunks))}, + encoding={ + "preferred_chunks": dict(zip(dims, pref_chunks, strict=True)) + }, ) } ) @@ -164,7 +166,7 @@ def test_split_chunks(self, shape, pref_chunks, req_chunks): final = xr.open_dataset( initial, engine=PassThroughBackendEntrypoint, - chunks=dict(zip(initial[self.var_name].dims, req_chunks)), + chunks=dict(zip(initial[self.var_name].dims, req_chunks, strict=True)), ) self.check_dataset(initial, final, explicit_chunks(req_chunks, shape)) @@ -196,6 +198,6 @@ def test_join_chunks(self, shape, pref_chunks, req_chunks): final = xr.open_dataset( initial, engine=PassThroughBackendEntrypoint, - chunks=dict(zip(initial[self.var_name].dims, req_chunks)), + chunks=dict(zip(initial[self.var_name].dims, req_chunks, strict=True)), ) self.check_dataset(initial, final, explicit_chunks(req_chunks, shape)) diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 604f27317b9..e84c77e54ed 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -64,7 +64,7 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree): assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"] assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"] - enc["/not/a/group"] = {"foo": "bar"} # type: ignore + enc["/not/a/group"] = {"foo": "bar"} # type: ignore[dict-item] with pytest.raises(ValueError, match="unexpected encoding group.*"): original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine) @@ -253,7 +253,7 @@ def test_zarr_encoding(self, tmpdir, simple_datatree): print(roundtrip_dt["/set2/a"].encoding) assert roundtrip_dt["/set2/a"].encoding["compressor"] == comp["compressor"] - enc["/not/a/group"] = 
{"foo": "bar"} # type: ignore + enc["/not/a/group"] = {"foo": "bar"} # type: ignore[dict-item] with pytest.raises(ValueError, match="unexpected encoding group.*"): original_dt.to_zarr(filepath, encoding=enc, engine="zarr") diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py index cede3e66fcf..ab1ac4a06d9 100644 --- a/xarray/tests/test_backends_file_manager.py +++ b/xarray/tests/test_backends_file_manager.py @@ -53,7 +53,7 @@ def test_file_manager_autoclose(warn_for_unclosed_files) -> None: if warn_for_unclosed_files: ctx = pytest.warns(RuntimeWarning) else: - ctx = assert_no_warnings() # type: ignore + ctx = assert_no_warnings() # type: ignore[assignment] with set_options(warn_for_unclosed_files=warn_for_unclosed_files): with ctx: diff --git a/xarray/tests/test_backends_lru_cache.py b/xarray/tests/test_backends_lru_cache.py index 5735e0327a0..fead97b7d57 100644 --- a/xarray/tests/test_backends_lru_cache.py +++ b/xarray/tests/test_backends_lru_cache.py @@ -33,7 +33,7 @@ def test_trivial() -> None: def test_invalid() -> None: with pytest.raises(TypeError): - LRUCache(maxsize=None) # type: ignore + LRUCache(maxsize=None) # type: ignore[arg-type] with pytest.raises(ValueError): LRUCache(maxsize=-1) diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index c4f2f51bd33..11e56e2adad 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -425,7 +425,9 @@ def test_neq(a, b): ] -@pytest.mark.parametrize(("a", "b"), zip(_EQ_TESTS_B, _EQ_TESTS_B_COPY), ids=_id_func) +@pytest.mark.parametrize( + ("a", "b"), zip(_EQ_TESTS_B, _EQ_TESTS_B_COPY, strict=True), ids=_id_func +) def test_eq(a, b): assert a == b @@ -572,7 +574,9 @@ def test_sub_error(offset, calendar): offset - initial -@pytest.mark.parametrize(("a", "b"), zip(_EQ_TESTS_A, _EQ_TESTS_B), ids=_id_func) +@pytest.mark.parametrize( + ("a", "b"), zip(_EQ_TESTS_A, _EQ_TESTS_B, strict=True), ids=_id_func +) def test_minus_offset(a, b): result = b - a expected = a @@ -581,7 +585,7 @@ def test_minus_offset(a, b): @pytest.mark.parametrize( ("a", "b"), - list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B)) # type: ignore[arg-type] + list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B, strict=True)) # type: ignore[arg-type] + [(YearEnd(month=1), YearEnd(month=2))], ids=_id_func, ) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 45ae049c08b..5879e6beed8 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -628,10 +628,10 @@ def test_cf_timedelta_2d() -> None: @pytest.mark.parametrize( ["deltas", "expected"], [ - (pd.to_timedelta(["1 day", "2 days"]), "days"), - (pd.to_timedelta(["1h", "1 day 1 hour"]), "hours"), - (pd.to_timedelta(["1m", "2m", np.nan]), "minutes"), - (pd.to_timedelta(["1m3s", "1m4s"]), "seconds"), + (pd.to_timedelta(["1 day", "2 days"]), "days"), # type: ignore[arg-type, unused-ignore] + (pd.to_timedelta(["1 day", "2 days"]), "days"), # type: ignore[arg-type, unused-ignore] + (pd.to_timedelta(["1 day", "2 days"]), "days"), # type: ignore[arg-type, unused-ignore] + (pd.to_timedelta(["1 day", "2 days"]), "days"), # type: ignore[arg-type, unused-ignore] ], ) def test_infer_timedelta_units(deltas, expected) -> None: diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 1c48dca825d..41ad75b0fea 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -249,7 +249,10 @@ def create_combined_ids(): def 
_create_combined_ids(shape): tile_ids = _create_tile_ids(shape) nums = range(len(tile_ids)) - return {tile_id: create_test_data(num) for tile_id, num in zip(tile_ids, nums)} + return { + tile_id: create_test_data(num) + for tile_id, num in zip(tile_ids, nums, strict=True) + } def _create_tile_ids(shape): diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index ab3108e7056..3a50a3f1724 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -578,7 +578,7 @@ def func(*x): variables = [xr.Variable("x", a) for a in arrays] data_arrays = [ xr.DataArray(v, {"x": c, "y": ("x", range(len(c)))}) - for v, c in zip(variables, [["a"], ["b", "c"]]) + for v, c in zip(variables, [["a"], ["b", "c"]], strict=True) ] datasets = [xr.Dataset({"data": data_array}) for data_array in data_arrays] @@ -1190,7 +1190,7 @@ def test_apply_dask() -> None: # unknown setting for dask array handling with pytest.raises(ValueError): - apply_ufunc(identity, array, dask="unknown") # type: ignore + apply_ufunc(identity, array, dask="unknown") # type: ignore[arg-type] def dask_safe_identity(x): return apply_ufunc(identity, x, dask="allowed") diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index e0dc105c925..7f7f14c8f16 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -276,7 +276,10 @@ def test_concat_multiple_datasets_missing_vars(include_day: bool) -> None: expected[name][i : i + 1, ...] = np.nan # set up the test data - datasets = [ds.drop_vars(varname) for ds, varname in zip(datasets, vars_to_drop)] + datasets = [ + ds.drop_vars(varname) + for ds, varname in zip(datasets, vars_to_drop, strict=True) + ] actual = concat(datasets, dim="day") @@ -1326,12 +1329,12 @@ def test_concat_preserve_coordinate_order() -> None: actual = concat([ds1, ds2], dim="time") # check dimension order - for act, exp in zip(actual.dims, expected.dims): + for act, exp in zip(actual.dims, expected.dims, strict=True): assert act == exp assert actual.sizes[act] == expected.sizes[exp] # check coordinate order - for act, exp in zip(actual.coords, expected.coords): + for act, exp in zip(actual.coords, expected.coords, strict=True): assert act == exp assert_identical(actual.coords[act], expected.coords[exp]) @@ -1345,12 +1348,12 @@ def test_concat_typing_check() -> None: TypeError, match="The elements in the input list need to be either all 'Dataset's or all 'DataArray's", ): - concat([ds, da], dim="foo") # type: ignore + concat([ds, da], dim="foo") # type: ignore[type-var] with pytest.raises( TypeError, match="The elements in the input list need to be either all 'Dataset's or all 'DataArray's", ): - concat([da, ds], dim="foo") # type: ignore + concat([da, ds], dim="foo") # type: ignore[type-var] def test_concat_not_all_indexes() -> None: diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index d7291d6abee..e6c69fc1ee1 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -119,7 +119,7 @@ def test_incompatible_attributes(self) -> None: Variable( ["t"], pd.date_range("2000-01-01", periods=3), {"units": "foobar"} ), - Variable(["t"], pd.to_timedelta(["1 day"]), {"units": "foobar"}), + Variable(["t"], pd.to_timedelta(["1 day"]), {"units": "foobar"}), # type: ignore[arg-type, unused-ignore] Variable(["t"], [0, 1, 2], {"add_offset": 0}, {"add_offset": 2}), Variable(["t"], [0, 1, 2], {"_FillValue": 0}, {"_FillValue": 2}), ] diff --git a/xarray/tests/test_coordinates.py 
b/xarray/tests/test_coordinates.py index f88e554d333..b167332d38b 100644 --- a/xarray/tests/test_coordinates.py +++ b/xarray/tests/test_coordinates.py @@ -64,7 +64,7 @@ def test_init_index_error(self) -> None: Coordinates(indexes={"x": idx}) with pytest.raises(TypeError, match=".* is not an `xarray.indexes.Index`"): - Coordinates(coords={"x": ("x", [1, 2, 3])}, indexes={"x": "not_an_xarray_index"}) # type: ignore + Coordinates(coords={"x": ("x", [1, 2, 3])}, indexes={"x": "not_an_xarray_index"}) # type: ignore[dict-item] def test_init_dim_sizes_conflict(self) -> None: with pytest.raises(ValueError): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 653e6dec43b..062f0525593 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -104,7 +104,8 @@ def test_chunk(self): self.assertLazyAndIdentical(self.eager_var, rechunked) expected_chunksizes = { - dim: chunks for dim, chunks in zip(self.lazy_var.dims, expected) + dim: chunks + for dim, chunks in zip(self.lazy_var.dims, expected, strict=True) } assert rechunked.chunksizes == expected_chunksizes @@ -354,7 +355,8 @@ def test_chunk(self) -> None: self.assertLazyAndIdentical(self.eager_array, rechunked) expected_chunksizes = { - dim: chunks for dim, chunks in zip(self.lazy_array.dims, expected) + dim: chunks + for dim, chunks in zip(self.lazy_array.dims, expected, strict=True) } assert rechunked.chunksizes == expected_chunksizes @@ -362,7 +364,8 @@ def test_chunk(self) -> None: lazy_dataset = self.lazy_array.to_dataset() eager_dataset = self.eager_array.to_dataset() expected_chunksizes = { - dim: chunks for dim, chunks in zip(lazy_dataset.dims, expected) + dim: chunks + for dim, chunks in zip(lazy_dataset.dims, expected, strict=True) } rechunked = lazy_dataset.chunk(chunks) @@ -737,7 +740,7 @@ def test_dataarray_getattr(self): nonindex_coord = build_dask_array("coord") a = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)}) with suppress(AttributeError): - getattr(a, "NOTEXIST") + a.NOTEXIST assert kernel_call_count == 0 def test_dataset_getattr(self): @@ -747,7 +750,7 @@ def test_dataset_getattr(self): nonindex_coord = build_dask_array("coord") ds = Dataset(data_vars={"a": ("x", data)}, coords={"y": ("x", nonindex_coord)}) with suppress(AttributeError): - getattr(ds, "NOTEXIST") + ds.NOTEXIST assert kernel_call_count == 0 def test_values(self): @@ -1797,6 +1800,6 @@ def test_minimize_graph_size(): actual = len([key for key in graph if var in key[0]]) # assert that we only include each chunk of an index variable # is only included once, not the product of number of chunks of - # all the other dimenions. + # all the other dimensions. # e.g. 
previously for 'x', actual == numchunks['y'] * numchunks['z'] assert actual == numchunks[var], (actual, numchunks[var]) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 9feab73d3d1..49df5dcde2d 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -283,7 +283,7 @@ def test_sizes(self) -> None: assert array.sizes == {"x": 3, "y": 4} assert tuple(array.sizes) == array.dims with pytest.raises(TypeError): - array.sizes["foo"] = 5 # type: ignore + array.sizes["foo"] = 5 # type: ignore[index] def test_encoding(self) -> None: expected = {"foo": "bar"} @@ -575,8 +575,8 @@ def test_equals_and_identical(self) -> None: def test_equals_failures(self) -> None: orig = DataArray(np.arange(5.0), {"a": 42}, dims="x") assert not orig.equals(np.arange(5)) # type: ignore[arg-type] - assert not orig.identical(123) # type: ignore - assert not orig.broadcast_equals({1: 2}) # type: ignore + assert not orig.identical(123) # type: ignore[arg-type] + assert not orig.broadcast_equals({1: 2}) # type: ignore[arg-type] def test_broadcast_equals(self) -> None: a = DataArray([0, 0], {"y": 0}, dims="x") @@ -889,7 +889,7 @@ def test_chunk(self) -> None: first_dask_name = blocked.data.name with pytest.warns(DeprecationWarning): - blocked = unblocked.chunk(chunks=((2, 1), (2, 2))) # type: ignore + blocked = unblocked.chunk(chunks=((2, 1), (2, 2))) # type: ignore[arg-type] assert blocked.chunks == ((2, 1), (2, 2)) assert blocked.data.name != first_dask_name @@ -2226,7 +2226,7 @@ def from_variables(cls, variables, options): indexed = da.set_xindex("foo", IndexWithOptions, opt=1) assert "foo" in indexed.xindexes - assert getattr(indexed.xindexes["foo"], "opt") == 1 + assert indexed.xindexes["foo"].opt == 1 # type: ignore[attr-defined] def test_dataset_getitem(self) -> None: dv = self.ds["foo"] @@ -2707,7 +2707,7 @@ def test_drop_index_labels(self) -> None: assert_identical(actual, expected) with pytest.warns(DeprecationWarning): - arr.drop([0, 1, 3], dim="y", errors="ignore") # type: ignore + arr.drop([0, 1, 3], dim="y", errors="ignore") # type: ignore[arg-type] def test_drop_index_positions(self) -> None: arr = DataArray(np.random.randn(2, 3), dims=["x", "y"]) @@ -2913,7 +2913,8 @@ def test_reduce_out(self) -> None: @pytest.mark.parametrize("skipna", [True, False, None]) @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) @pytest.mark.parametrize( - "axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]]) + "axis, dim", + zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]], strict=True), ) def test_quantile(self, q, axis, dim, skipna, compute_backend) -> None: va = self.va.copy(deep=True) @@ -4055,7 +4056,7 @@ def test_dot(self) -> None: with pytest.raises(NotImplementedError): da.dot(dm3.to_dataset(name="dm")) with pytest.raises(TypeError): - da.dot(dm3.values) # type: ignore + da.dot(dm3.values) # type: ignore[type-var] def test_dot_align_coords(self) -> None: # GH 3694 @@ -4520,7 +4521,7 @@ def test_query( # test error handling with pytest.raises(ValueError): - aa.query("a > 5") # type: ignore # must be dict or kwargs + aa.query("a > 5") # type: ignore[arg-type] # must be dict or kwargs with pytest.raises(ValueError): aa.query(x=(a > 5)) # must be query string with pytest.raises(UndefinedVariableError): @@ -5342,7 +5343,7 @@ def test_min( minindex = [ x if y is None or ar.dtype.kind == "O" else y - for x, y in zip(minindex, nanindex) + for x, y in zip(minindex, nanindex, strict=True) ] expected2list = [ ar.isel(y=yi).isel(x=indi, 
drop=True) for yi, indi in enumerate(minindex) @@ -5387,7 +5388,7 @@ def test_max( maxindex = [ x if y is None or ar.dtype.kind == "O" else y - for x, y in zip(maxindex, nanindex) + for x, y in zip(maxindex, nanindex, strict=True) ] expected2list = [ ar.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(maxindex) @@ -5439,7 +5440,7 @@ def test_argmin( minindex = [ x if y is None or ar.dtype.kind == "O" else y - for x, y in zip(minindex, nanindex) + for x, y in zip(minindex, nanindex, strict=True) ] expected2list = [ indarr.isel(y=yi).isel(x=indi, drop=True) @@ -5492,7 +5493,7 @@ def test_argmax( maxindex = [ x if y is None or ar.dtype.kind == "O" else y - for x, y in zip(maxindex, nanindex) + for x, y in zip(maxindex, nanindex, strict=True) ] expected2list = [ indarr.isel(y=yi).isel(x=indi, drop=True) @@ -5588,7 +5589,7 @@ def test_idxmin( # skipna=False minindex3 = [ x if y is None or ar0.dtype.kind == "O" else y - for x, y in zip(minindex0, nanindex) + for x, y in zip(minindex0, nanindex, strict=True) ] expected3list = [ coordarr0.isel(y=yi).isel(x=indi, drop=True) @@ -5730,7 +5731,7 @@ def test_idxmax( # skipna=False maxindex3 = [ x if y is None or ar0.dtype.kind == "O" else y - for x, y in zip(maxindex0, nanindex) + for x, y in zip(maxindex0, nanindex, strict=True) ] expected3list = [ coordarr0.isel(y=yi).isel(x=indi, drop=True) @@ -5830,7 +5831,7 @@ def test_argmin_dim( minindex = [ x if y is None or ar.dtype.kind == "O" else y - for x, y in zip(minindex, nanindex) + for x, y in zip(minindex, nanindex, strict=True) ] expected2list = [ indarr.isel(y=yi).isel(x=indi, drop=True) @@ -5897,7 +5898,7 @@ def test_argmax_dim( maxindex = [ x if y is None or ar.dtype.kind == "O" else y - for x, y in zip(maxindex, nanindex) + for x, y in zip(maxindex, nanindex, strict=True) ] expected2list = [ indarr.isel(y=yi).isel(x=indi, drop=True) @@ -6650,8 +6651,8 @@ def test_to_and_from_iris(self) -> None: ), ) - for coord, orginal_key in zip((actual.coords()), original.coords): - original_coord = original.coords[orginal_key] + for coord, original_key in zip((actual.coords()), original.coords, strict=True): + original_coord = original.coords[original_key] assert coord.var_name == original_coord.name assert_array_equal( coord.points, CFDatetimeCoder().encode(original_coord.variable).values @@ -6726,8 +6727,8 @@ def test_to_and_from_iris_dask(self) -> None: ), ) - for coord, orginal_key in zip((actual.coords()), original.coords): - original_coord = original.coords[orginal_key] + for coord, original_key in zip((actual.coords()), original.coords, strict=True): + original_coord = original.coords[original_key] assert coord.var_name == original_coord.name assert_array_equal( coord.points, CFDatetimeCoder().encode(original_coord.variable).values @@ -7159,7 +7160,7 @@ def test_result_as_expected(self) -> None: def test_error_on_ellipsis_without_list(self) -> None: da = DataArray([[1, 2], [1, 2]], dims=("x", "y")) with pytest.raises(ValueError): - da.stack(flat=...) + da.stack(flat=...) 
# type: ignore[arg-type] def test_nD_coord_dataarray() -> None: @@ -7202,3 +7203,17 @@ def test_lazy_data_variable_not_loaded(): da = xr.DataArray(v) # No data needs to be accessed, so no error should be raised xr.DataArray(da) + + +def test_unstack_index_var() -> None: + source = xr.DataArray(range(2), dims=["x"], coords=[["a", "b"]]) + da = source.x + da = da.assign_coords(y=("x", ["c", "d"]), z=("x", ["e", "f"])) + da = da.set_index(x=["y", "z"]) + actual = da.unstack("x") + expected = xr.DataArray( + np.array([["a", np.nan], [np.nan, "b"]], dtype=object), + coords={"y": ["c", "d"], "z": ["e", "f"]}, + name="x", + ) + assert_identical(actual, expected) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f2e712e334c..fc2b2251c2c 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -582,7 +582,7 @@ def test_constructor_pandas_single(self) -> None: for a in das: pandas_obj = a.to_pandas() - ds_based_on_pandas = Dataset(pandas_obj) # type: ignore # TODO: improve typing of __init__ + ds_based_on_pandas = Dataset(pandas_obj) # type: ignore[arg-type] # TODO: improve typing of __init__ for dim in ds_based_on_pandas.data_vars: assert isinstance(dim, int) assert_array_equal(ds_based_on_pandas[dim], pandas_obj[dim]) @@ -1217,7 +1217,7 @@ def test_chunk_by_frequency(self, freq: str, calendar: str, add_gap: bool) -> No ΔN = 28 time = xr.date_range( "2001-01-01", periods=N + ΔN, freq="D", calendar=calendar - ).to_numpy() + ).to_numpy(copy=True) if add_gap: # introduce an empty bin time[31 : 31 + ΔN] = np.datetime64("NaT") @@ -3212,7 +3212,7 @@ def test_rename_perserve_attrs_encoding(self) -> None: # test propagate attrs/encoding to new variable(s) created from Index object original = Dataset(coords={"x": ("x", [0, 1, 2])}) expected = Dataset(coords={"y": ("y", [0, 1, 2])}) - for ds, dim in zip([original, expected], ["x", "y"]): + for ds, dim in zip([original, expected], ["x", "y"], strict=True): ds[dim].attrs = {"foo": "bar"} ds[dim].encoding = {"foo": "bar"} @@ -3713,7 +3713,7 @@ def test_set_xindex(self) -> None: class NotAnIndex: ... 
with pytest.raises(TypeError, match=".*not a subclass of xarray.Index"): - ds.set_xindex("foo", NotAnIndex) # type: ignore + ds.set_xindex("foo", NotAnIndex) # type: ignore[arg-type] with pytest.raises(ValueError, match="those variables don't exist"): ds.set_xindex("not_a_coordinate", PandasIndex) @@ -3740,7 +3740,7 @@ def from_variables(cls, variables, options): return cls(options["opt"]) indexed = ds.set_xindex("foo", IndexWithOptions, opt=1) - assert getattr(indexed.xindexes["foo"], "opt") == 1 + assert indexed.xindexes["foo"].opt == 1 # type: ignore[attr-defined] def test_stack(self) -> None: ds = Dataset( @@ -6450,8 +6450,8 @@ def test_full_like(self) -> None: expected = ds.copy(deep=True) # https://github.com/python/mypy/issues/3004 - expected["d1"].values = [2, 2, 2] # type: ignore - expected["d2"].values = [2.0, 2.0, 2.0] # type: ignore + expected["d1"].values = [2, 2, 2] # type: ignore[assignment] + expected["d2"].values = [2.0, 2.0, 2.0] # type: ignore[assignment] assert expected["d1"].dtype == int assert expected["d2"].dtype == float assert_identical(expected, actual) @@ -6459,8 +6459,8 @@ def test_full_like(self) -> None: # override dtype actual = full_like(ds, fill_value=True, dtype=bool) expected = ds.copy(deep=True) - expected["d1"].values = [True, True, True] # type: ignore - expected["d2"].values = [True, True, True] # type: ignore + expected["d1"].values = [True, True, True] # type: ignore[assignment] + expected["d2"].values = [True, True, True] # type: ignore[assignment] assert expected["d1"].dtype == bool assert expected["d2"].dtype == bool assert_identical(expected, actual) @@ -6742,7 +6742,7 @@ def test_pad(self, padded_dim_name, constant_values) -> None: else: np.testing.assert_equal(padded.sizes[ds_dim_name], ds_dim) - # check if coord "numbers" with dimention dim3 is paded correctly + # check if coord "numbers" with dimension dim3 is padded correctly if padded_dim_name == "dim3": assert padded["numbers"][[0, -1]].isnull().all() # twarning: passes but dtype changes from int to float @@ -6975,7 +6975,7 @@ def test_query(self, backend, engine, parser) -> None: # test error handling with pytest.raises(ValueError): - ds.query("a > 5") # type: ignore # must be dict or kwargs + ds.query("a > 5") # type: ignore[arg-type] # must be dict or kwargs with pytest.raises(ValueError): ds.query(x=(a > 5)) with pytest.raises(IndexError): @@ -7615,4 +7615,4 @@ def test_transpose_error() -> None: "transpose requires dim to be passed as multiple arguments. Expected `'y', 'x'`. 
Received `['y', 'x']` instead" ), ): - ds.transpose(["y", "x"]) # type: ignore + ds.transpose(["y", "x"]) # type: ignore[arg-type] diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 9a15376a1f8..6d208a5cf98 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -7,6 +7,7 @@ import pytest import xarray as xr +from xarray import Dataset from xarray.core.datatree import DataTree from xarray.core.datatree_ops import _MAPPED_DOCSTRING_ADDENDUM, insert_doc_addendum from xarray.core.treenode import NotFoundInTreeError @@ -16,14 +17,14 @@ class TestTreeCreation: def test_empty(self): - dt: DataTree = DataTree(name="root") + dt = DataTree(name="root") assert dt.name == "root" assert dt.parent is None assert dt.children == {} assert_identical(dt.to_dataset(), xr.Dataset()) def test_unnamed(self): - dt: DataTree = DataTree() + dt = DataTree() assert dt.name is None def test_bad_names(self): @@ -33,80 +34,104 @@ def test_bad_names(self): with pytest.raises(ValueError): DataTree(name="folder/data") + def test_data_arg(self): + ds = xr.Dataset({"foo": 42}) + tree: DataTree = DataTree(data=ds) + assert_identical(tree.to_dataset(), ds) + + with pytest.raises(TypeError): + DataTree(data=xr.DataArray(42, name="foo")) # type: ignore[arg-type] + class TestFamilyTree: - def test_setparent_unnamed_child_node_fails(self): - john: DataTree = DataTree(name="john") - with pytest.raises(ValueError, match="unnamed"): - DataTree(parent=john) + def test_dont_modify_children_inplace(self): + # GH issue 9196 + child = DataTree() + DataTree(children={"child": child}) + assert child.parent is None def test_create_two_children(self): root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) set1_data = xr.Dataset({"a": 0, "b": 1}) - - root: DataTree = DataTree(data=root_data) - set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) - DataTree(name="set1", parent=root) - DataTree(name="set2", parent=set1) + root = DataTree.from_dict( + {"/": root_data, "/set1": set1_data, "/set1/set2": None} + ) + assert root["/set1"].name == "set1" + assert root["/set1/set2"].name == "set2" def test_create_full_tree(self, simple_datatree): - root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) - set1_data = xr.Dataset({"a": 0, "b": 1}) - set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}) + d = simple_datatree.to_dict() + d_keys = list(d.keys()) - root: DataTree = DataTree(data=root_data) - set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) - DataTree(name="set1", parent=set1) - DataTree(name="set2", parent=set1) - set2: DataTree = DataTree(name="set2", parent=root, data=set2_data) - DataTree(name="set1", parent=set2) - DataTree(name="set3", parent=root) + expected_keys = [ + "/", + "/set1", + "/set2", + "/set3", + "/set1/set1", + "/set1/set2", + "/set2/set1", + ] - expected = simple_datatree - assert root.identical(expected) + assert d_keys == expected_keys class TestNames: def test_child_gets_named_on_attach(self): - sue: DataTree = DataTree() - mary: DataTree = DataTree(children={"Sue": sue}) # noqa - assert sue.name == "Sue" + sue = DataTree() + mary = DataTree(children={"Sue": sue}) # noqa + assert mary.children["Sue"].name == "Sue" class TestPaths: def test_path_property(self): - sue: DataTree = DataTree() - mary: DataTree = DataTree(children={"Sue": sue}) - john: DataTree = DataTree(children={"Mary": mary}) - assert sue.path == "/Mary/Sue" + john = DataTree.from_dict( + { + "/Mary/Sue": DataTree(), 
+ } + ) + assert john["/Mary/Sue"].path == "/Mary/Sue" assert john.path == "/" def test_path_roundtrip(self): - sue: DataTree = DataTree() - mary: DataTree = DataTree(children={"Sue": sue}) - john: DataTree = DataTree(children={"Mary": mary}) - assert john[sue.path] is sue + john = DataTree.from_dict( + { + "/Mary/Sue": DataTree(), + } + ) + assert john["/Mary/Sue"].name == "Sue" def test_same_tree(self): - mary: DataTree = DataTree() - kate: DataTree = DataTree() - john: DataTree = DataTree(children={"Mary": mary, "Kate": kate}) # noqa - assert mary.same_tree(kate) + john = DataTree.from_dict( + { + "/Mary": DataTree(), + "/Kate": DataTree(), + } + ) + assert john["/Mary"].same_tree(john["/Kate"]) def test_relative_paths(self): - sue: DataTree = DataTree() - mary: DataTree = DataTree(children={"Sue": sue}) - annie: DataTree = DataTree() - john: DataTree = DataTree(children={"Mary": mary, "Annie": annie}) + john = DataTree.from_dict( + { + "/Mary/Sue": DataTree(), + "/Annie": DataTree(), + } + ) + sue_result = john["Mary/Sue"] + if isinstance(sue_result, DataTree): + sue: DataTree = sue_result + + annie_result = john["Annie"] + if isinstance(annie_result, DataTree): + annie: DataTree = annie_result - result = sue.relative_to(john) - assert result == "Mary/Sue" + assert sue.relative_to(john) == "Mary/Sue" assert john.relative_to(sue) == "../.." assert annie.relative_to(sue) == "../../Annie" assert sue.relative_to(annie) == "../Mary/Sue" assert sue.relative_to(sue) == "." - evil_kate: DataTree = DataTree() + evil_kate = DataTree() with pytest.raises( NotFoundInTreeError, match="nodes do not lie within the same tree" ): @@ -116,15 +141,15 @@ def test_relative_paths(self): class TestStoreDatasets: def test_create_with_data(self): dat = xr.Dataset({"a": 0}) - john: DataTree = DataTree(name="john", data=dat) + john = DataTree(name="john", data=dat) assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): - DataTree(name="mary", parent=john, data="junk") # type: ignore[arg-type] + DataTree(name="mary", data="junk") # type: ignore[arg-type] def test_set_data(self): - john: DataTree = DataTree(name="john") + john = DataTree(name="john") dat = xr.Dataset({"a": 0}) john.ds = dat # type: ignore[assignment] @@ -134,17 +159,17 @@ def test_set_data(self): john.ds = "junk" # type: ignore[assignment] def test_has_data(self): - john: DataTree = DataTree(name="john", data=xr.Dataset({"a": 0})) + john = DataTree(name="john", data=xr.Dataset({"a": 0})) assert john.has_data - john_no_data: DataTree = DataTree(name="john", data=None) + john_no_data = DataTree(name="john", data=None) assert not john_no_data.has_data def test_is_hollow(self): - john: DataTree = DataTree(data=xr.Dataset({"a": 0})) + john = DataTree(data=xr.Dataset({"a": 0})) assert john.is_hollow - eve: DataTree = DataTree(children={"john": john}) + eve = DataTree(children={"john": john}) assert eve.is_hollow eve.ds = xr.Dataset({"a": 1}) # type: ignore[assignment] @@ -168,13 +193,21 @@ def test_to_dataset(self): class TestVariablesChildrenNameCollisions: def test_parent_already_has_variable_with_childs_name(self): - dt: DataTree = DataTree(data=xr.Dataset({"a": [0], "b": 1})) with pytest.raises(KeyError, match="already contains a variable named a"): - DataTree(name="a", data=None, parent=dt) + DataTree.from_dict({"/": xr.Dataset({"a": [0], "b": 1}), "/a": None}) + + def test_parent_already_has_variable_with_childs_name_update(self): + dt = DataTree(data=xr.Dataset({"a": [0], "b": 1})) + with pytest.raises(ValueError, match="already 
contains a variable named a"): + dt.update({"a": DataTree()}) def test_assign_when_already_child_with_variables_name(self): - dt: DataTree = DataTree(data=None) - DataTree(name="a", data=None, parent=dt) + dt = DataTree.from_dict( + { + "/a": DataTree(), + } + ) + with pytest.raises(ValueError, match="node already contains a variable"): dt.ds = xr.Dataset({"a": 0}) # type: ignore[assignment] @@ -190,44 +223,48 @@ class TestGet: ... class TestGetItem: def test_getitem_node(self): - folder1: DataTree = DataTree(name="folder1") - results: DataTree = DataTree(name="results", parent=folder1) - highres: DataTree = DataTree(name="highres", parent=results) - assert folder1["results"] is results - assert folder1["results/highres"] is highres + folder1 = DataTree.from_dict( + { + "/results/highres": DataTree(), + } + ) + + assert folder1["results"].name == "results" + assert folder1["results/highres"].name == "highres" def test_getitem_self(self): - dt: DataTree = DataTree() + dt = DataTree() assert dt["."] is dt def test_getitem_single_data_variable(self): data = xr.Dataset({"temp": [0, 50]}) - results: DataTree = DataTree(name="results", data=data) + results = DataTree(name="results", data=data) assert_identical(results["temp"], data["temp"]) def test_getitem_single_data_variable_from_node(self): data = xr.Dataset({"temp": [0, 50]}) - folder1: DataTree = DataTree(name="folder1") - results: DataTree = DataTree(name="results", parent=folder1) - DataTree(name="highres", parent=results, data=data) + folder1 = DataTree.from_dict( + { + "/results/highres": data, + } + ) assert_identical(folder1["results/highres/temp"], data["temp"]) def test_getitem_nonexistent_node(self): - folder1: DataTree = DataTree(name="folder1") - DataTree(name="results", parent=folder1) + folder1 = DataTree.from_dict({"/results": DataTree()}, name="folder1") with pytest.raises(KeyError): folder1["results/highres"] def test_getitem_nonexistent_variable(self): data = xr.Dataset({"temp": [0, 50]}) - results: DataTree = DataTree(name="results", data=data) + results = DataTree(name="results", data=data) with pytest.raises(KeyError): results["pressure"] @pytest.mark.xfail(reason="Should be deprecated in favour of .subset") def test_getitem_multiple_data_variables(self): data = xr.Dataset({"temp": [0, 50], "p": [5, 8, 7]}) - results: DataTree = DataTree(name="results", data=data) + results = DataTree(name="results", data=data) assert_identical(results[["temp", "p"]], data[["temp", "p"]]) # type: ignore[index] @pytest.mark.xfail( @@ -235,13 +272,13 @@ def test_getitem_multiple_data_variables(self): ) def test_getitem_dict_like_selection_access_to_dataset(self): data = xr.Dataset({"temp": [0, 50]}) - results: DataTree = DataTree(name="results", data=data) + results = DataTree(name="results", data=data) assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index] class TestUpdate: def test_update(self): - dt: DataTree = DataTree() + dt = DataTree() dt.update({"foo": xr.DataArray(0), "a": DataTree()}) expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "a": None}) assert_equal(dt, expected) @@ -249,13 +286,13 @@ def test_update(self): def test_update_new_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1.update({"results": da}) expected = da.rename("results") assert_equal(folder1["results"], expected) def test_update_doesnt_alter_child_name(self): - dt: DataTree = DataTree() + dt = DataTree() 
dt.update({"foo": xr.DataArray(0), "a": DataTree(name="b")}) assert "a" in dt.children child = dt["a"] @@ -307,7 +344,9 @@ def test_copy(self, create_test_datatree): for copied in [dt.copy(deep=False), copy(dt)]: assert_identical(dt, copied) - for node, copied_node in zip(dt.root.subtree, copied.root.subtree): + for node, copied_node in zip( + dt.root.subtree, copied.root.subtree, strict=True + ): assert node.encoding == copied_node.encoding # Note: IndexVariable objects with string dtype are always # copied because of xarray.core.util.safe_cast_to_index. @@ -331,6 +370,14 @@ def test_copy_subtree(self): assert_identical(actual, expected) + def test_copy_coord_inheritance(self) -> None: + tree = DataTree.from_dict( + {"/": xr.Dataset(coords={"x": [0, 1]}), "/c": DataTree()} + ) + tree2 = tree.copy() + node_ds = tree2.children["c"].to_dataset(inherited=False) + assert_identical(node_ds, xr.Dataset()) + def test_deepcopy(self, create_test_datatree): dt = create_test_datatree() @@ -340,7 +387,9 @@ def test_deepcopy(self, create_test_datatree): for copied in [dt.copy(deep=True), deepcopy(dt)]: assert_identical(dt, copied) - for node, copied_node in zip(dt.root.subtree, copied.root.subtree): + for node, copied_node in zip( + dt.root.subtree, copied.root.subtree, strict=True + ): assert node.encoding == copied_node.encoding # Note: IndexVariable objects with string dtype are always # copied because of xarray.core.util.safe_cast_to_index. @@ -376,8 +425,8 @@ def test_copy_with_data(self, create_test_datatree): class TestSetItem: def test_setitem_new_child_node(self): - john: DataTree = DataTree(name="john") - mary: DataTree = DataTree(name="mary") + john = DataTree(name="john") + mary = DataTree(name="mary") john["mary"] = mary grafted_mary = john["mary"] @@ -385,38 +434,37 @@ def test_setitem_new_child_node(self): assert grafted_mary.name == "mary" def test_setitem_unnamed_child_node_becomes_named(self): - john2: DataTree = DataTree(name="john2") + john2 = DataTree(name="john2") john2["sonny"] = DataTree() assert john2["sonny"].name == "sonny" def test_setitem_new_grandchild_node(self): - john: DataTree = DataTree(name="john") - mary: DataTree = DataTree(name="mary", parent=john) - rose: DataTree = DataTree(name="rose") - john["mary/rose"] = rose + john = DataTree.from_dict({"/Mary/Rose": DataTree()}) + new_rose = DataTree(data=xr.Dataset({"x": 0})) + john["Mary/Rose"] = new_rose - grafted_rose = john["mary/rose"] - assert grafted_rose.parent is mary - assert grafted_rose.name == "rose" + grafted_rose = john["Mary/Rose"] + assert grafted_rose.parent is john["/Mary"] + assert grafted_rose.name == "Rose" def test_grafted_subtree_retains_name(self): - subtree: DataTree = DataTree(name="original_subtree_name") - root: DataTree = DataTree(name="root") + subtree = DataTree(name="original_subtree_name") + root = DataTree(name="root") root["new_subtree_name"] = subtree # noqa assert subtree.name == "original_subtree_name" def test_setitem_new_empty_node(self): - john: DataTree = DataTree(name="john") + john = DataTree(name="john") john["mary"] = DataTree() mary = john["mary"] assert isinstance(mary, DataTree) assert_identical(mary.to_dataset(), xr.Dataset()) def test_setitem_overwrite_data_in_node_with_none(self): - john: DataTree = DataTree(name="john") - mary: DataTree = DataTree(name="mary", parent=john, data=xr.Dataset()) + john = DataTree.from_dict({"/mary": xr.Dataset()}, name="john") + john["mary"] = DataTree() - assert_identical(mary.to_dataset(), xr.Dataset()) + 
assert_identical(john["mary"].to_dataset(), xr.Dataset()) john.ds = xr.Dataset() # type: ignore[assignment] with pytest.raises(ValueError, match="has no name"): @@ -425,57 +473,57 @@ def test_setitem_overwrite_data_in_node_with_none(self): @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_on_this_node(self): data = xr.Dataset({"temp": [0, 50]}) - results: DataTree = DataTree(name="results") + results = DataTree(name="results") results["."] = data assert_identical(results.to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node(self): data = xr.Dataset({"temp": [0, 50]}) - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1["results"] = data assert_identical(folder1["results"].to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self): data = xr.Dataset({"temp": [0, 50]}) - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1["results/highres"] = data assert_identical(folder1["results/highres"].to_dataset(), data) def test_setitem_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1["results"] = da expected = da.rename("results") assert_equal(folder1["results"], expected) def test_setitem_unnamed_dataarray(self): data = xr.DataArray([0, 50]) - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1["results"] = data assert_equal(folder1["results"], data) def test_setitem_variable(self): var = xr.Variable(data=[0, 50], dims="x") - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1["results"] = var assert_equal(folder1["results"], xr.DataArray(var)) def test_setitem_coerce_to_dataarray(self): - folder1: DataTree = DataTree(name="folder1") + folder1 = DataTree(name="folder1") folder1["results"] = 0 assert_equal(folder1["results"], xr.DataArray(0)) def test_setitem_add_new_variable_to_empty_node(self): - results: DataTree = DataTree(name="results") + results = DataTree(name="results") results["pressure"] = xr.DataArray(data=[2, 3]) assert "pressure" in results.ds results["temp"] = xr.Variable(data=[10, 11], dims=["x"]) assert "temp" in results.ds # What if there is a path to traverse first? - results_with_path: DataTree = DataTree(name="results") + results_with_path = DataTree(name="results") results_with_path["highres/pressure"] = xr.DataArray(data=[2, 3]) assert "pressure" in results_with_path["highres"].ds results_with_path["highres/temp"] = xr.Variable(data=[10, 11], dims=["x"]) @@ -483,14 +531,46 @@ def test_setitem_add_new_variable_to_empty_node(self): def test_setitem_dataarray_replace_existing_node(self): t = xr.Dataset({"temp": [0, 50]}) - results: DataTree = DataTree(name="results", data=t) + results = DataTree(name="results", data=t) p = xr.DataArray(data=[2, 3]) results["pressure"] = p expected = t.assign(pressure=p) assert_identical(results.to_dataset(), expected) -class TestDictionaryInterface: ... 
+def test_delitem(): + ds = Dataset({"a": 0}, coords={"x": ("x", [1, 2]), "z": "a"}) + dt = DataTree(ds, children={"c": DataTree()}) + + with pytest.raises(KeyError): + del dt["foo"] + + # test delete children + del dt["c"] + assert dt.children == {} + assert set(dt.variables) == {"x", "z", "a"} + with pytest.raises(KeyError): + del dt["c"] + + # test delete variables + del dt["a"] + assert set(dt.coords) == {"x", "z"} + with pytest.raises(KeyError): + del dt["a"] + + # test delete coordinates + del dt["z"] + assert set(dt.coords) == {"x"} + with pytest.raises(KeyError): + del dt["z"] + + # test delete indexed coordinates + del dt["x"] + assert dt.variables == {} + assert dt.coords == {} + assert dt.indexes == {} + with pytest.raises(KeyError): + del dt["x"] class TestTreeFromDict: @@ -540,8 +620,8 @@ def test_full(self, simple_datatree): ] def test_datatree_values(self): - dat1: DataTree = DataTree(data=xr.Dataset({"a": 1})) - expected: DataTree = DataTree() + dat1 = DataTree(data=xr.Dataset({"a": 1})) + expected = DataTree() expected["a"] = dat1 actual = DataTree.from_dict({"a": dat1}) @@ -586,11 +666,16 @@ def test_insertion_order(self): # despite 'Bart' coming before 'Lisa' when sorted alphabetically assert list(reversed["Homer"].children.keys()) == ["Lisa", "Bart"] + def test_array_values(self): + data = {"foo": xr.DataArray(1, name="bar")} + with pytest.raises(TypeError): + DataTree.from_dict(data) # type: ignore[arg-type] + class TestDatasetView: def test_view_contents(self): ds = create_test_data() - dt: DataTree = DataTree(data=ds) + dt = DataTree(data=ds) assert ds.identical( dt.ds ) # this only works because Dataset.identical doesn't check types @@ -598,8 +683,13 @@ def test_view_contents(self): def test_immutability(self): # See issue https://github.com/xarray-contrib/datatree/issues/38 - dt: DataTree = DataTree(name="root", data=None) - DataTree(name="a", data=None, parent=dt) + dt = DataTree.from_dict( + { + "/": None, + "/a": None, + }, + name="root", + ) with pytest.raises( AttributeError, match="Mutation of the DatasetView is not allowed" @@ -616,7 +706,7 @@ def test_immutability(self): def test_methods(self): ds = create_test_data() - dt: DataTree = DataTree(data=ds) + dt = DataTree(data=ds) assert ds.mean().identical(dt.ds.mean()) assert isinstance(dt.ds.mean(), xr.Dataset) @@ -637,7 +727,7 @@ def test_init_via_type(self): dims=["x", "y", "time"], coords={"area": (["x", "y"], np.random.rand(3, 4))}, ).to_dataset(name="data") - dt: DataTree = DataTree(data=a) + dt = DataTree(data=a) def weighted_mean(ds): return ds.weighted(ds.area).mean(["x", "y"]) @@ -655,7 +745,7 @@ def test_attribute_access(self, create_test_datatree): assert key in dir(dt) # dims - assert_equal(dt["a"]["y"], getattr(dt.a, "y")) + assert_equal(dt["a"]["y"], dt.a.y) assert "y" in dir(dt["a"]) # children @@ -688,7 +778,7 @@ def test_operation_with_attrs_but_no_data(self): class TestRepr: def test_repr(self): - dt: DataTree = DataTree.from_dict( + dt = DataTree.from_dict( { "/": xr.Dataset( {"e": (("x",), [1.0, 2.0])}, @@ -827,12 +917,12 @@ def test_inconsistent_dims(self): } ) - dt: DataTree = DataTree() + dt = DataTree() dt["/a"] = xr.DataArray([1.0, 2.0], dims=["x"]) with pytest.raises(ValueError, match=expected_msg): dt["/b/c"] = xr.DataArray([3.0], dims=["x"]) - b: DataTree = DataTree(data=xr.Dataset({"c": (("x",), [3.0])})) + b = DataTree(data=xr.Dataset({"c": (("x",), [3.0])})) with pytest.raises(ValueError, match=expected_msg): DataTree( data=xr.Dataset({"a": (("x",), [1.0, 2.0])}), @@ -864,13 
+954,13 @@ def test_inconsistent_child_indexes(self): } ) - dt: DataTree = DataTree() - dt.ds = xr.Dataset(coords={"x": [1.0]}) # type: ignore + dt = DataTree() + dt.ds = xr.Dataset(coords={"x": [1.0]}) # type: ignore[assignment] dt["/b"] = DataTree() with pytest.raises(ValueError, match=expected_msg): dt["/b"].ds = xr.Dataset(coords={"x": [2.0]}) - b: DataTree = DataTree(xr.Dataset(coords={"x": [2.0]})) + b = DataTree(xr.Dataset(coords={"x": [2.0]})) with pytest.raises(ValueError, match=expected_msg): DataTree(data=xr.Dataset(coords={"x": [1.0]}), children={"b": b}) @@ -899,14 +989,14 @@ def test_inconsistent_grandchild_indexes(self): } ) - dt: DataTree = DataTree() - dt.ds = xr.Dataset(coords={"x": [1.0]}) # type: ignore + dt = DataTree() + dt.ds = xr.Dataset(coords={"x": [1.0]}) # type: ignore[assignment] dt["/b/c"] = DataTree() with pytest.raises(ValueError, match=expected_msg): dt["/b/c"].ds = xr.Dataset(coords={"x": [2.0]}) - c: DataTree = DataTree(xr.Dataset(coords={"x": [2.0]})) - b: DataTree = DataTree(children={"c": c}) + c = DataTree(xr.Dataset(coords={"x": [2.0]})) + b = DataTree(children={"c": c}) with pytest.raises(ValueError, match=expected_msg): DataTree(data=xr.Dataset(coords={"x": [1.0]}), children={"b": b}) @@ -933,7 +1023,7 @@ def test_inconsistent_grandchild_dims(self): } ) - dt: DataTree = DataTree() + dt = DataTree() dt["/a"] = xr.DataArray([1.0, 2.0], dims=["x"]) with pytest.raises(ValueError, match=expected_msg): dt["/b/c/d"] = xr.DataArray([3.0], dims=["x"]) @@ -961,7 +1051,7 @@ def test_drop_nodes(self): assert childless.children == {} def test_assign(self): - dt: DataTree = DataTree() + dt = DataTree() expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "/a": None}) # kwargs form @@ -1008,7 +1098,7 @@ def f(x, tree, y): class TestSubset: def test_match(self): # TODO is this example going to cause problems with case sensitivity? 
- dt: DataTree = DataTree.from_dict( + dt = DataTree.from_dict( { "/a/A": None, "/a/B": None, @@ -1026,8 +1116,8 @@ def test_match(self): assert_identical(result, expected) def test_filter(self): - simpsons: DataTree = DataTree.from_dict( - d={ + simpsons = DataTree.from_dict( + { "/": xr.Dataset({"age": 83}), "/Herbert": xr.Dataset({"age": 40}), "/Homer": xr.Dataset({"age": 39}), @@ -1038,7 +1128,7 @@ def test_filter(self): name="Abe", ) expected = DataTree.from_dict( - d={ + { "/": xr.Dataset({"age": 83}), "/Herbert": xr.Dataset({"age": 40}), "/Homer": xr.Dataset({"age": 39}), @@ -1052,44 +1142,51 @@ def test_filter(self): class TestDSMethodInheritance: def test_dataset_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt: DataTree = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) + dt = DataTree.from_dict( + { + "/": ds, + "/results": ds, + } + ) - expected: DataTree = DataTree(data=ds.isel(x=1)) - DataTree(name="results", parent=expected, data=ds.isel(x=1)) + expected = DataTree.from_dict( + { + "/": ds.isel(x=1), + "/results": ds.isel(x=1), + } + ) result = dt.isel(x=1) assert_equal(result, expected) def test_reduce_method(self): ds = xr.Dataset({"a": ("x", [False, True, False])}) - dt: DataTree = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) + dt = DataTree.from_dict({"/": ds, "/results": ds}) - expected: DataTree = DataTree(data=ds.any()) - DataTree(name="results", parent=expected, data=ds.any()) + expected = DataTree.from_dict({"/": ds.any(), "/results": ds.any()}) result = dt.any() assert_equal(result, expected) def test_nan_reduce_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt: DataTree = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) + dt = DataTree.from_dict({"/": ds, "/results": ds}) - expected: DataTree = DataTree(data=ds.mean()) - DataTree(name="results", parent=expected, data=ds.mean()) + expected = DataTree.from_dict({"/": ds.mean(), "/results": ds.mean()}) result = dt.mean() assert_equal(result, expected) def test_cum_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt: DataTree = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) + dt = DataTree.from_dict({"/": ds, "/results": ds}) - expected: DataTree = DataTree(data=ds.cumsum()) - DataTree(name="results", parent=expected, data=ds.cumsum()) + expected = DataTree.from_dict( + { + "/": ds.cumsum(), + "/results": ds.cumsum(), + } + ) result = dt.cumsum() assert_equal(result, expected) @@ -1099,11 +1196,9 @@ class TestOps: def test_binary_op_on_int(self): ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) - dt: DataTree = DataTree(data=ds1) - DataTree(name="subnode", data=ds2, parent=dt) + dt = DataTree.from_dict({"/": ds1, "/subnode": ds2}) - expected: DataTree = DataTree(data=ds1 * 5) - DataTree(name="subnode", data=ds2 * 5, parent=expected) + expected = DataTree.from_dict({"/": ds1 * 5, "/subnode": ds2 * 5}) # TODO: Remove ignore when ops.py is migrated? 
result: DataTree = dt * 5 # type: ignore[assignment,operator] @@ -1112,12 +1207,21 @@ def test_binary_op_on_int(self): def test_binary_op_on_dataset(self): ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) - dt: DataTree = DataTree(data=ds1) - DataTree(name="subnode", data=ds2, parent=dt) + dt = DataTree.from_dict( + { + "/": ds1, + "/subnode": ds2, + } + ) + other_ds = xr.Dataset({"z": ("z", [0.1, 0.2])}) - expected: DataTree = DataTree(data=ds1 * other_ds) - DataTree(name="subnode", data=ds2 * other_ds, parent=expected) + expected = DataTree.from_dict( + { + "/": ds1 * other_ds, + "/subnode": ds2 * other_ds, + } + ) result = dt * other_ds assert_equal(result, expected) @@ -1125,14 +1229,13 @@ def test_binary_op_on_dataset(self): def test_binary_op_on_datatree(self): ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) - dt: DataTree = DataTree(data=ds1) - DataTree(name="subnode", data=ds2, parent=dt) - expected: DataTree = DataTree(data=ds1 * ds1) - DataTree(name="subnode", data=ds2 * ds2, parent=expected) + dt = DataTree.from_dict({"/": ds1, "/subnode": ds2}) + + expected = DataTree.from_dict({"/": ds1 * ds1, "/subnode": ds2 * ds2}) # TODO: Remove ignore when ops.py is migrated? - result: DataTree = dt * dt # type: ignore[operator] + result = dt * dt # type: ignore[operator] assert_equal(result, expected) diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index b8b55613c4a..1d2595cd013 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -19,8 +19,8 @@ def test_not_a_tree(self): check_isomorphic("s", 1) # type: ignore[arg-type] def test_different_widths(self): - dt1 = DataTree.from_dict(d={"a": empty}) - dt2 = DataTree.from_dict(d={"b": empty, "c": empty}) + dt1 = DataTree.from_dict({"a": empty}) + dt2 = DataTree.from_dict({"b": empty, "c": empty}) expected_err_str = ( "Number of children on node '/' of the left object: 1\n" "Number of children on node '/' of the right object: 2" @@ -74,10 +74,9 @@ def test_checking_from_root(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() real_root: DataTree = DataTree(name="real root") - dt2.name = "not_real_root" - dt2.parent = real_root + real_root["not_real_root"] = dt2 with pytest.raises(TreeIsomorphismError): - check_isomorphic(dt1, dt2, check_from_root=True) + check_isomorphic(dt1, real_root, check_from_root=True) class TestMapOverSubTree: @@ -321,7 +320,7 @@ def weighted_mean(ds): def test_alter_inplace_forbidden(self): simpsons = DataTree.from_dict( - d={ + { "/": xr.Dataset({"age": 83}), "/Herbert": xr.Dataset({"age": 40}), "/Homer": xr.Dataset({"age": 39}), diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 3bbae55b105..da263f1b30e 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -112,7 +112,9 @@ def test_first(self): array([[8, 5, 2, nan], [nan, 13, 14, 15]]), array([[2, 5, 8], [13, 17, 21]]), ] - for axis, expected in zip([0, 1, 2, -3, -2, -1], 2 * expected_results): + for axis, expected in zip( + [0, 1, 2, -3, -2, -1], 2 * expected_results, strict=True + ): actual = first(self.x, axis) assert_array_equal(expected, actual) @@ -133,7 +135,9 @@ def test_last(self): array([[8, 9, 10, nan], [nan, 21, 18, 15]]), array([[2, 6, 10], [15, 18, 21]]), ] - for axis, expected in zip([0, 1, 2, -3, -2, -1], 2 * expected_results): + for axis, expected in zip( + [0, 1, 2, 
-3, -2, -1], 2 * expected_results, strict=True + ): actual = last(self.x, axis) assert_array_equal(expected, actual) diff --git a/xarray/tests/test_extensions.py b/xarray/tests/test_extensions.py index 7cfffd68620..92df269cb4f 100644 --- a/xarray/tests/test_extensions.py +++ b/xarray/tests/test_extensions.py @@ -51,12 +51,12 @@ def foo(self): # check descriptor assert ds.demo.__doc__ == "Demo accessor." # TODO: typing doesn't seem to work with accessors - assert xr.Dataset.demo.__doc__ == "Demo accessor." # type: ignore + assert xr.Dataset.demo.__doc__ == "Demo accessor." # type: ignore[attr-defined] assert isinstance(ds.demo, DemoAccessor) - assert xr.Dataset.demo is DemoAccessor # type: ignore + assert xr.Dataset.demo is DemoAccessor # type: ignore[attr-defined] # ensure we can remove it - del xr.Dataset.demo # type: ignore + del xr.Dataset.demo # type: ignore[attr-defined] assert not hasattr(xr.Dataset, "demo") with pytest.warns(Warning, match="overriding a preexisting attribute"): diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 9d0eb81bace..a2fef9d9b6b 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -118,9 +118,9 @@ def test_format_items(self) -> None: np.arange(4) * np.timedelta64(500, "ms"), "00:00:00 00:00:00.500000 00:00:01 00:00:01.500000", ), - (pd.to_timedelta(["NaT", "0s", "1s", "NaT"]), "NaT 00:00:00 00:00:01 NaT"), + (pd.to_timedelta(["NaT", "0s", "1s", "NaT"]), "NaT 00:00:00 00:00:01 NaT"), # type: ignore[arg-type, unused-ignore] ( - pd.to_timedelta(["1 day 1 hour", "1 day", "0 hours"]), + pd.to_timedelta(["1 day 1 hour", "1 day", "0 hours"]), # type: ignore[arg-type, unused-ignore] "1 days 01:00:00 1 days 00:00:00 0 days 00:00:00", ), ([1, 2, 3], "1 2 3"), @@ -615,7 +615,7 @@ def test_array_scalar_format(self) -> None: # Test numpy arrays raises: var = xr.DataArray([0.1, 0.2]) - with pytest.raises(NotImplementedError) as excinfo: # type: ignore + with pytest.raises(NotImplementedError) as excinfo: # type: ignore[assignment] format(var, ".2f") assert "Using format_spec is only supported" in str(excinfo.value) @@ -652,13 +652,18 @@ def test_datatree_print_node_with_data(self): "Data variables", "*empty*", ] - for expected_line, printed_line in zip(expected, printout.splitlines()): + for expected_line, printed_line in zip( + expected, printout.splitlines(), strict=True + ): assert expected_line in printed_line def test_datatree_printout_nested_node(self): dat = xr.Dataset({"a": [0, 2]}) - root: DataTree = DataTree(name="root") - DataTree(name="results", data=dat, parent=root) + root = DataTree.from_dict( + { + "/results": dat, + } + ) printout = str(root) assert printout.splitlines()[3].startswith(" ") @@ -841,7 +846,7 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None: attrs = {k: 2 for k in b} coords = {_c: np.array([0, 1], dtype=np.uint64) for _c in c} data_vars = dict() - for v, _c in zip(a, coords.items()): + for v, _c in zip(a, coords.items(), strict=True): data_vars[v] = xr.DataArray( name=v, data=np.array([3, 4], dtype=np.uint64), diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index fc04b49fabc..fa6172c5d66 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime import operator import warnings from unittest import mock @@ -13,7 +14,7 @@ from xarray import DataArray, Dataset, Variable from xarray.core.alignment import broadcast from xarray.core.groupby import 
_consolidate_slices -from xarray.core.types import InterpOptions +from xarray.core.types import InterpOptions, ResampleCompatible from xarray.groupers import ( BinGrouper, EncodedGroups, @@ -29,7 +30,7 @@ create_test_data, has_cftime, has_flox, - has_pandas_ge_2_1, + has_pandas_ge_2_2, requires_cftime, requires_dask, requires_flox, @@ -135,7 +136,7 @@ def test_multi_index_groupby_sum() -> None: ) assert_equal(expected, ds) - if not has_pandas_ge_2_1: + if not has_pandas_ge_2_2: # the next line triggers a mysterious multiindex error on pandas 2.0 return @@ -583,7 +584,7 @@ def test_groupby_repr(obj, dim) -> None: N = len(np.unique(obj[dim])) expected = f"<{obj.__class__.__name__}GroupBy" expected += f", grouped over 1 grouper(s), {N} groups in total:" - expected += f"\n\t{dim!r}: {N} groups with labels " + expected += f"\n {dim!r}: {N} groups with labels " if dim == "x": expected += "1, 2, 3, 4, 5>" elif dim == "y": @@ -600,7 +601,7 @@ def test_groupby_repr_datetime(obj) -> None: actual = repr(obj.groupby("t.month")) expected = f"<{obj.__class__.__name__}GroupBy" expected += ", grouped over 1 grouper(s), 12 groups in total:\n" - expected += "\t'month': 12 groups with labels " + expected += " 'month': 12 groups with labels " expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>" assert actual == expected @@ -757,7 +758,6 @@ def test_groupby_none_group_name() -> None: def test_groupby_getitem(dataset) -> None: - assert_identical(dataset.sel(x=["a"]), dataset.groupby("x")["a"]) assert_identical(dataset.sel(z=[1]), dataset.groupby("z")[1]) assert_identical(dataset.foo.sel(x=["a"]), dataset.foo.groupby("x")["a"]) @@ -788,7 +788,7 @@ def test_groupby_dataset() -> None: ("b", data.isel(x=[1])), ("c", data.isel(x=[2])), ] - for actual1, expected1 in zip(groupby, expected_items): + for actual1, expected1 in zip(groupby, expected_items, strict=True): assert actual1[0] == expected1[0] assert_equal(actual1[1], expected1[1]) @@ -1235,12 +1235,12 @@ def test_stack_groupby_unsorted_coord(self) -> None: def test_groupby_iter(self) -> None: for (act_x, act_dv), (exp_x, exp_ds) in zip( - self.dv.groupby("y"), self.ds.groupby("y") + self.dv.groupby("y"), self.ds.groupby("y"), strict=True ): assert exp_x == act_x assert_identical(exp_ds["foo"], act_dv) for (_, exp_dv), (_, act_dv) in zip( - self.dv.groupby("x"), self.dv.groupby("x") + self.dv.groupby("x"), self.dv.groupby("x"), strict=True ): assert_identical(exp_dv, act_dv) @@ -1706,7 +1706,7 @@ def test_groupby_bins_multidim(self) -> None: bincoord = np.array( [ pd.Interval(left, right, closed="right") - for left, right in zip(bins[:-1], bins[1:]) + for left, right in zip(bins[:-1], bins[1:], strict=True) ], dtype=object, ) @@ -1773,7 +1773,21 @@ def test_groupby_fastpath_for_monotonic(self, use_flox: bool) -> None: class TestDataArrayResample: @pytest.mark.parametrize("use_cftime", [True, False]) - def test_resample(self, use_cftime: bool) -> None: + @pytest.mark.parametrize( + "resample_freq", + [ + "24h", + "123456s", + "1234567890us", + pd.Timedelta(hours=2), + pd.offsets.MonthBegin(), + pd.offsets.Second(123456), + datetime.timedelta(days=1, hours=6), + ], + ) + def test_resample( + self, use_cftime: bool, resample_freq: ResampleCompatible + ) -> None: if use_cftime and not has_cftime: pytest.skip() times = xr.date_range( @@ -1795,23 +1809,23 @@ def resample_as_pandas(array, *args, **kwargs): array = DataArray(np.arange(10), [("time", times)]) - actual = array.resample(time="24h").mean() - expected = resample_as_pandas(array, "24h") + actual = 
array.resample(time=resample_freq).mean() + expected = resample_as_pandas(array, resample_freq) assert_identical(expected, actual) - actual = array.resample(time="24h").reduce(np.mean) + actual = array.resample(time=resample_freq).reduce(np.mean) assert_identical(expected, actual) - actual = array.resample(time="24h", closed="right").mean() - expected = resample_as_pandas(array, "24h", closed="right") + actual = array.resample(time=resample_freq, closed="right").mean() + expected = resample_as_pandas(array, resample_freq, closed="right") assert_identical(expected, actual) with pytest.raises(ValueError, match=r"Index must be monotonic"): - array[[2, 0, 1]].resample(time="1D") + array[[2, 0, 1]].resample(time=resample_freq) reverse = array.isel(time=slice(-1, None, -1)) with pytest.raises(ValueError): - reverse.resample(time="1D").mean() + reverse.resample(time=resample_freq).mean() @pytest.mark.parametrize("use_cftime", [True, False]) def test_resample_doctest(self, use_cftime: bool) -> None: @@ -2206,6 +2220,67 @@ def test_resample_origin(self) -> None: class TestDatasetResample: + @pytest.mark.parametrize("use_cftime", [True, False]) + @pytest.mark.parametrize( + "resample_freq", + [ + "24h", + "123456s", + "1234567890us", + pd.Timedelta(hours=2), + pd.offsets.MonthBegin(), + pd.offsets.Second(123456), + datetime.timedelta(days=1, hours=6), + ], + ) + def test_resample( + self, use_cftime: bool, resample_freq: ResampleCompatible + ) -> None: + if use_cftime and not has_cftime: + pytest.skip() + times = xr.date_range( + "2000-01-01", freq="6h", periods=10, use_cftime=use_cftime + ) + + def resample_as_pandas(ds, *args, **kwargs): + ds_ = ds.copy(deep=True) + if use_cftime: + ds_["time"] = times.to_datetimeindex() + result = Dataset.from_dataframe( + ds_.to_dataframe().resample(*args, **kwargs).mean() + ) + if use_cftime: + result = result.convert_calendar( + calendar="standard", use_cftime=use_cftime + ) + return result + + ds = Dataset( + { + "foo": ("time", np.random.randint(1, 1000, 10)), + "bar": ("time", np.random.randint(1, 1000, 10)), + "time": times, + } + ) + + actual = ds.resample(time=resample_freq).mean() + expected = resample_as_pandas(ds, resample_freq) + assert_identical(expected, actual) + + actual = ds.resample(time=resample_freq).reduce(np.mean) + assert_identical(expected, actual) + + actual = ds.resample(time=resample_freq, closed="right").mean() + expected = resample_as_pandas(ds, resample_freq, closed="right") + assert_identical(expected, actual) + + with pytest.raises(ValueError, match=r"Index must be monotonic"): + ds.isel(time=[2, 0, 1]).resample(time=resample_freq) + + reverse = ds.isel(time=slice(-1, None, -1)) + with pytest.raises(ValueError): + reverse.resample(time=resample_freq).mean() + def test_resample_and_first(self) -> None: times = pd.date_range("2000-01-01", freq="6h", periods=10) ds = Dataset( @@ -2635,6 +2710,36 @@ def test_weather_data_resample(use_flox): assert expected.location.attrs == ds.location.attrs +@pytest.mark.parametrize("as_dataset", [True, False]) +def test_multiple_groupers_string(as_dataset) -> None: + obj = DataArray( + np.array([1, 2, 3, 0, 2, np.nan]), + dims="d", + coords=dict( + labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])), + labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])), + ), + name="foo", + ) + + if as_dataset: + obj = obj.to_dataset() # type: ignore[assignment] + + expected = obj.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper()).mean() + actual = obj.groupby(("labels1", "labels2")).mean() + 
assert_identical(expected, actual) + + # Passes `"labels2"` to squeeze; will raise an error around kwargs rather than the + # warning & type error in the future + with pytest.warns(FutureWarning): + with pytest.raises(TypeError): + obj.groupby("labels1", "labels2") # type: ignore[arg-type, misc] + with pytest.raises(ValueError): + obj.groupby("labels1", foo="bar") # type: ignore[arg-type] + with pytest.raises(ValueError): + obj.groupby("labels1", foo=UniqueGrouper()) + + @pytest.mark.parametrize("use_flox", [True, False]) def test_multiple_groupers(use_flox) -> None: da = DataArray( diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 48e254b037b..cf14e5c8f43 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -674,13 +674,15 @@ def test_copy_indexes(self, indexes) -> None: copied, index_vars = indexes.copy_indexes() assert copied.keys() == indexes.keys() - for new, original in zip(copied.values(), indexes.values()): + for new, original in zip(copied.values(), indexes.values(), strict=True): assert new.equals(original) # check unique index objects preserved assert copied["z"] is copied["one"] is copied["two"] assert index_vars.keys() == indexes.variables.keys() - for new, original in zip(index_vars.values(), indexes.variables.values()): + for new, original in zip( + index_vars.values(), indexes.variables.values(), strict=True + ): assert_identical(new, original) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 7151c669fbc..5c03881242b 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -406,7 +406,7 @@ def test_errors(use_dask: bool) -> None: for method in ["akima", "spline"]: with pytest.raises(ValueError): - da.interp(x=[0.5, 1.5], method=method) # type: ignore + da.interp(x=[0.5, 1.5], method=method) # type: ignore[arg-type] # not sorted if use_dask: @@ -421,9 +421,9 @@ def test_errors(use_dask: bool) -> None: # invalid method with pytest.raises(ValueError): - da.interp(x=[2, 0], method="boo") # type: ignore + da.interp(x=[2, 0], method="boo") # type: ignore[arg-type] with pytest.raises(ValueError): - da.interp(y=[2, 0], method="boo") # type: ignore + da.interp(y=[2, 0], method="boo") # type: ignore[arg-type] # object-type DataArray cannot be interpolated da = xr.DataArray(["a", "b", "c"], dims="x", coords={"x": [0, 1, 2]}) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index bd75f633b82..bf90074a7cc 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -600,9 +600,8 @@ def test_get_clean_interp_index_cf_calendar(cf_da, calendar): @requires_cftime -@pytest.mark.parametrize( - ("calendar", "freq"), zip(["gregorian", "proleptic_gregorian"], ["1D", "1ME", "1Y"]) -) +@pytest.mark.parametrize("calendar", ["gregorian", "proleptic_gregorian"]) +@pytest.mark.parametrize("freq", ["1D", "1ME", "1YE"]) def test_get_clean_interp_index_dt(cf_da, calendar, freq): """In the gregorian case, the index should be proportional to normal datetimes.""" g = cf_da(calendar, freq=freq) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 3ebaeff712b..3d47f3e1803 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -224,16 +224,16 @@ def test_label_from_attrs(self) -> None: assert label_from_attrs(da) == long_latex_name def test1d(self) -> None: - self.darray[:, 0, 0].plot() + self.darray[:, 0, 0].plot() # type: ignore[call-arg] with pytest.raises(ValueError, match=r"x must be one of None, 'dim_0'"): - 
self.darray[:, 0, 0].plot(x="dim_1") + self.darray[:, 0, 0].plot(x="dim_1") # type: ignore[call-arg] with pytest.raises(TypeError, match=r"complex128"): - (self.darray[:, 0, 0] + 1j).plot() + (self.darray[:, 0, 0] + 1j).plot() # type: ignore[call-arg] def test_1d_bool(self) -> None: - xr.ones_like(self.darray[:, 0, 0], dtype=bool).plot() + xr.ones_like(self.darray[:, 0, 0], dtype=bool).plot() # type: ignore[call-arg] def test_1d_x_y_kw(self) -> None: z = np.arange(10) @@ -243,17 +243,17 @@ def test_1d_x_y_kw(self) -> None: f, axs = plt.subplots(3, 1, squeeze=False) for aa, (x, y) in enumerate(xy): - da.plot(x=x, y=y, ax=axs.flat[aa]) + da.plot(x=x, y=y, ax=axs.flat[aa]) # type: ignore[call-arg] with pytest.raises(ValueError, match=r"Cannot specify both"): - da.plot(x="z", y="z") + da.plot(x="z", y="z") # type: ignore[call-arg] error_msg = "must be one of None, 'z'" with pytest.raises(ValueError, match=rf"x {error_msg}"): - da.plot(x="f") + da.plot(x="f") # type: ignore[call-arg] with pytest.raises(ValueError, match=rf"y {error_msg}"): - da.plot(y="f") + da.plot(y="f") # type: ignore[call-arg] def test_multiindex_level_as_coord(self) -> None: da = xr.DataArray( @@ -264,11 +264,11 @@ def test_multiindex_level_as_coord(self) -> None: da = da.set_index(x=["a", "b"]) for x in ["a", "b"]: - h = da.plot(x=x)[0] + h = da.plot(x=x)[0] # type: ignore[call-arg] assert_array_equal(h.get_xdata(), da[x].values) for y in ["a", "b"]: - h = da.plot(y=y)[0] + h = da.plot(y=y)[0] # type: ignore[call-arg] assert_array_equal(h.get_ydata(), da[y].values) # Test for bug in GH issue #2725 @@ -302,10 +302,10 @@ def test_line_plot_along_1d_coord(self) -> None: coords={"x": x_coord, "time": t_coord}, ) - line = da.plot(x="time", hue="x")[0] + line = da.plot(x="time", hue="x")[0] # type: ignore[call-arg] assert_array_equal(line.get_xdata(), da.coords["time"].values) - line = da.plot(y="time", hue="x")[0] + line = da.plot(y="time", hue="x")[0] # type: ignore[call-arg] assert_array_equal(line.get_ydata(), da.coords["time"].values) def test_line_plot_wrong_hue(self) -> None: @@ -315,7 +315,7 @@ def test_line_plot_wrong_hue(self) -> None: ) with pytest.raises(ValueError, match="hue must be one of"): - da.plot(x="t", hue="wrong_coord") + da.plot(x="t", hue="wrong_coord") # type: ignore[call-arg] def test_2d_line(self) -> None: with pytest.raises(ValueError, match=r"hue"): @@ -387,7 +387,7 @@ def test_2d_coord_line_plot_coords_transpose_invariant(self) -> None: def test_2d_before_squeeze(self) -> None: a = DataArray(easy_array((1, 5))) - a.plot() + a.plot() # type: ignore[call-arg] def test2d_uniform_calls_imshow(self) -> None: assert self.imshow_called(self.darray[:, :, 0].plot.imshow) @@ -517,7 +517,7 @@ def test_contourf_cmap_set_with_bad_under_over(self) -> None: assert cmap(np.inf) == cmap_expected(np.inf) def test3d(self) -> None: - self.darray.plot() + self.darray.plot() # type: ignore[call-arg] def test_can_pass_in_axis(self) -> None: self.pass_in_axis(self.darray.plot) @@ -605,10 +605,10 @@ def test_geo_data(self) -> None: dims=("y", "x"), coords={"lon": (("y", "x"), lon), "lat": (("y", "x"), lat)}, ) - da.plot(x="lon", y="lat") + da.plot(x="lon", y="lat") # type: ignore[call-arg] ax = plt.gca() assert ax.has_data() - da.plot(x="lat", y="lon") + da.plot(x="lat", y="lon") # type: ignore[call-arg] ax = plt.gca() assert ax.has_data() @@ -619,7 +619,7 @@ def test_datetime_dimension(self) -> None: a = DataArray( easy_array((nrow, ncol)), coords=[("time", time), ("y", range(ncol))] ) - a.plot() + a.plot() # type: 
ignore[call-arg] ax = plt.gca() assert ax.has_data() @@ -631,7 +631,7 @@ def test_date_dimension(self) -> None: a = DataArray( easy_array((nrow, ncol)), coords=[("time", time), ("y", range(ncol))] ) - a.plot() + a.plot() # type: ignore[call-arg] ax = plt.gca() assert ax.has_data() @@ -641,24 +641,24 @@ def test_convenient_facetgrid(self) -> None: a = easy_array((10, 15, 4)) d = DataArray(a, dims=["y", "x", "z"]) d.coords["z"] = list("abcd") - g = d.plot(x="x", y="y", col="z", col_wrap=2, cmap="cool") + g = d.plot(x="x", y="y", col="z", col_wrap=2, cmap="cool") # type: ignore[call-arg] assert_array_equal(g.axs.shape, [2, 2]) for ax in g.axs.flat: assert ax.has_data() with pytest.raises(ValueError, match=r"[Ff]acet"): - d.plot(x="x", y="y", col="z", ax=plt.gca()) + d.plot(x="x", y="y", col="z", ax=plt.gca()) # type: ignore[call-arg] with pytest.raises(ValueError, match=r"[Ff]acet"): - d[0].plot(x="x", y="y", col="z", ax=plt.gca()) + d[0].plot(x="x", y="y", col="z", ax=plt.gca()) # type: ignore[call-arg] @pytest.mark.slow def test_subplot_kws(self) -> None: a = easy_array((10, 15, 4)) d = DataArray(a, dims=["y", "x", "z"]) d.coords["z"] = list("abcd") - g = d.plot( + g = d.plot( # type: ignore[call-arg] x="x", y="y", col="z", @@ -672,58 +672,58 @@ def test_subplot_kws(self) -> None: @pytest.mark.slow def test_plot_size(self) -> None: - self.darray[:, 0, 0].plot(figsize=(13, 5)) + self.darray[:, 0, 0].plot(figsize=(13, 5)) # type: ignore[call-arg] assert tuple(plt.gcf().get_size_inches()) == (13, 5) - self.darray.plot(figsize=(13, 5)) + self.darray.plot(figsize=(13, 5)) # type: ignore[call-arg] assert tuple(plt.gcf().get_size_inches()) == (13, 5) - self.darray.plot(size=5) + self.darray.plot(size=5) # type: ignore[call-arg] assert plt.gcf().get_size_inches()[1] == 5 - self.darray.plot(size=5, aspect=2) + self.darray.plot(size=5, aspect=2) # type: ignore[call-arg] assert tuple(plt.gcf().get_size_inches()) == (10, 5) with pytest.raises(ValueError, match=r"cannot provide both"): - self.darray.plot(ax=plt.gca(), figsize=(3, 4)) + self.darray.plot(ax=plt.gca(), figsize=(3, 4)) # type: ignore[call-arg] with pytest.raises(ValueError, match=r"cannot provide both"): - self.darray.plot(size=5, figsize=(3, 4)) + self.darray.plot(size=5, figsize=(3, 4)) # type: ignore[call-arg] with pytest.raises(ValueError, match=r"cannot provide both"): - self.darray.plot(size=5, ax=plt.gca()) + self.darray.plot(size=5, ax=plt.gca()) # type: ignore[call-arg] with pytest.raises(ValueError, match=r"cannot provide `aspect`"): - self.darray.plot(aspect=1) + self.darray.plot(aspect=1) # type: ignore[call-arg] @pytest.mark.slow @pytest.mark.filterwarnings("ignore:tight_layout cannot") def test_convenient_facetgrid_4d(self) -> None: a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=["y", "x", "columns", "rows"]) - g = d.plot(x="x", y="y", col="columns", row="rows") + g = d.plot(x="x", y="y", col="columns", row="rows") # type: ignore[call-arg] assert_array_equal(g.axs.shape, [3, 2]) for ax in g.axs.flat: assert ax.has_data() with pytest.raises(ValueError, match=r"[Ff]acet"): - d.plot(x="x", y="y", col="columns", ax=plt.gca()) + d.plot(x="x", y="y", col="columns", ax=plt.gca()) # type: ignore[call-arg] def test_coord_with_interval(self) -> None: """Test line plot with intervals.""" bins = [-1, 0, 1, 2] - self.darray.groupby_bins("dim_0", bins).mean(...).plot() + self.darray.groupby_bins("dim_0", bins).mean(...).plot() # type: ignore[call-arg] def test_coord_with_interval_x(self) -> None: """Test line plot with intervals 
explicitly on x axis.""" bins = [-1, 0, 1, 2] - self.darray.groupby_bins("dim_0", bins).mean(...).plot(x="dim_0_bins") + self.darray.groupby_bins("dim_0", bins).mean(...).plot(x="dim_0_bins") # type: ignore[call-arg] def test_coord_with_interval_y(self) -> None: """Test line plot with intervals explicitly on y axis.""" bins = [-1, 0, 1, 2] - self.darray.groupby_bins("dim_0", bins).mean(...).plot(y="dim_0_bins") + self.darray.groupby_bins("dim_0", bins).mean(...).plot(y="dim_0_bins") # type: ignore[call-arg] def test_coord_with_interval_xy(self) -> None: """Test line plot with intervals on both x and y axes.""" @@ -737,7 +737,7 @@ def test_labels_with_units_with_interval(self, dim) -> None: arr = self.darray.groupby_bins("dim_0", bins).mean(...) arr.dim_0_bins.attrs["units"] = "m" - (mappable,) = arr.plot(**{dim: "dim_0_bins"}) + (mappable,) = arr.plot(**{dim: "dim_0_bins"}) # type: ignore[arg-type] ax = mappable.figure.gca() actual = getattr(ax, f"get_{dim}label")() @@ -747,9 +747,9 @@ def test_labels_with_units_with_interval(self, dim) -> None: def test_multiplot_over_length_one_dim(self) -> None: a = easy_array((3, 1, 1, 1)) d = DataArray(a, dims=("x", "col", "row", "hue")) - d.plot(col="col") - d.plot(row="row") - d.plot(hue="hue") + d.plot(col="col") # type: ignore[call-arg] + d.plot(row="row") # type: ignore[call-arg] + d.plot(hue="hue") # type: ignore[call-arg] class TestPlot1D(PlotTestCase): @@ -760,27 +760,27 @@ def setUp(self) -> None: self.darray.period.attrs["units"] = "s" def test_xlabel_is_index_name(self) -> None: - self.darray.plot() + self.darray.plot() # type: ignore[call-arg] assert "period [s]" == plt.gca().get_xlabel() def test_no_label_name_on_x_axis(self) -> None: - self.darray.plot(y="period") + self.darray.plot(y="period") # type: ignore[call-arg] assert "" == plt.gca().get_xlabel() def test_no_label_name_on_y_axis(self) -> None: - self.darray.plot() + self.darray.plot() # type: ignore[call-arg] assert "" == plt.gca().get_ylabel() def test_ylabel_is_data_name(self) -> None: self.darray.name = "temperature" self.darray.attrs["units"] = "degrees_Celsius" - self.darray.plot() + self.darray.plot() # type: ignore[call-arg] assert "temperature [degrees_Celsius]" == plt.gca().get_ylabel() def test_xlabel_is_data_name(self) -> None: self.darray.name = "temperature" self.darray.attrs["units"] = "degrees_Celsius" - self.darray.plot(y="period") + self.darray.plot(y="period") # type: ignore[call-arg] assert "temperature [degrees_Celsius]" == plt.gca().get_xlabel() def test_format_string(self) -> None: @@ -859,12 +859,12 @@ def test_step_with_hue_and_where(self, where) -> None: assert hdl[0].get_drawstyle() == f"steps-{where}" def test_drawstyle_steps(self) -> None: - hdl = self.darray[0].plot(hue="dim_2", drawstyle="steps") + hdl = self.darray[0].plot(hue="dim_2", drawstyle="steps") # type: ignore[call-arg] assert hdl[0].get_drawstyle() == "steps" @pytest.mark.parametrize("where", ["pre", "post", "mid"]) def test_drawstyle_steps_with_where(self, where) -> None: - hdl = self.darray[0].plot(hue="dim_2", drawstyle=f"steps-{where}") + hdl = self.darray[0].plot(hue="dim_2", drawstyle=f"steps-{where}") # type: ignore[call-arg] assert hdl[0].get_drawstyle() == f"steps-{where}" def test_coord_with_interval_step(self) -> None: @@ -907,29 +907,29 @@ def setUp(self) -> None: self.darray = DataArray(easy_array((2, 3, 4))) def test_3d_array(self) -> None: - self.darray.plot.hist() + self.darray.plot.hist() # type: ignore[call-arg] def test_xlabel_uses_name(self) -> None: self.darray.name = 
"testpoints" self.darray.attrs["units"] = "testunits" - self.darray.plot.hist() + self.darray.plot.hist() # type: ignore[call-arg] assert "testpoints [testunits]" == plt.gca().get_xlabel() def test_title_is_histogram(self) -> None: self.darray.coords["d"] = 10 - self.darray.plot.hist() + self.darray.plot.hist() # type: ignore[call-arg] assert "d = 10" == plt.gca().get_title() def test_can_pass_in_kwargs(self) -> None: nbins = 5 - self.darray.plot.hist(bins=nbins) + self.darray.plot.hist(bins=nbins) # type: ignore[call-arg] assert nbins == len(plt.gca().patches) def test_can_pass_in_axis(self) -> None: self.pass_in_axis(self.darray.plot.hist) def test_primitive_returned(self) -> None: - n, bins, patches = self.darray.plot.hist() + n, bins, patches = self.darray.plot.hist() # type: ignore[call-arg] assert isinstance(n, np.ndarray) assert isinstance(bins, np.ndarray) assert isinstance(patches, mpl.container.BarContainer) @@ -938,11 +938,11 @@ def test_primitive_returned(self) -> None: @pytest.mark.slow def test_plot_nans(self) -> None: self.darray[0, 0, 0] = np.nan - self.darray.plot.hist() + self.darray.plot.hist() # type: ignore[call-arg] def test_hist_coord_with_interval(self) -> None: ( - self.darray.groupby_bins("dim_0", [-1, 0, 1, 2]) + self.darray.groupby_bins("dim_0", [-1, 0, 1, 2]) # type: ignore[call-arg] .mean(...) .plot.hist(range=(-1, 2)) ) @@ -1146,6 +1146,7 @@ def test_norm_sets_vmin_vmax(self) -> None: ], ["neither", "neither", "both", "max", "min"], [7, None, None, None, None], + strict=True, ): test_min = vmin if norm.vmin is None else norm.vmin test_max = vmax if norm.vmax is None else norm.vmax @@ -1167,7 +1168,7 @@ def setUp(self): y = np.arange(start=9, stop=-7, step=-3) xy = np.dstack(np.meshgrid(x, y)) distance = np.linalg.norm(xy, axis=2) - self.darray = DataArray(distance, list(zip(("y", "x"), (y, x)))) + self.darray = DataArray(distance, list(zip(("y", "x"), (y, x), strict=True))) self.data_min = distance.min() self.data_max = distance.max() yield @@ -1246,7 +1247,7 @@ def test_discrete_colormap_int_levels(self) -> None: def test_discrete_colormap_list_levels_and_vmin_or_vmax(self) -> None: levels = [0, 5, 10, 15] - primitive = self.darray.plot(levels=levels, vmin=-3, vmax=20) + primitive = self.darray.plot(levels=levels, vmin=-3, vmax=20) # type: ignore[call-arg] assert primitive.norm.vmax == max(levels) assert primitive.norm.vmin == min(levels) @@ -1316,11 +1317,11 @@ def test_1d_raises_valueerror(self) -> None: self.plotfunc(self.darray[0, :]) def test_bool(self) -> None: - xr.ones_like(self.darray, dtype=bool).plot() + xr.ones_like(self.darray, dtype=bool).plot() # type: ignore[call-arg] def test_complex_raises_typeerror(self) -> None: with pytest.raises(TypeError, match=r"complex128"): - (self.darray + 1j).plot() + (self.darray + 1j).plot() # type: ignore[call-arg] def test_3d_raises_valueerror(self) -> None: a = DataArray(easy_array((2, 3, 4))) @@ -1720,10 +1721,10 @@ def test_colormap_error_norm_and_vmin_vmax(self) -> None: norm = mpl.colors.LogNorm(0.1, 1e1) with pytest.raises(ValueError): - self.darray.plot(norm=norm, vmin=2) + self.darray.plot(norm=norm, vmin=2) # type: ignore[call-arg] with pytest.raises(ValueError): - self.darray.plot(norm=norm, vmax=2) + self.darray.plot(norm=norm, vmax=2) # type: ignore[call-arg] @pytest.mark.slow @@ -1862,7 +1863,7 @@ def test_dont_infer_interval_breaks_for_cartopy(self) -> None: # Regression for GH 781 ax = plt.gca() # Simulate a Cartopy Axis - setattr(ax, "projection", True) + ax.projection = True # type: 
ignore[attr-defined] artist = self.plotmethod(x="x2d", y="y2d", ax=ax) assert isinstance(artist, mpl.collections.QuadMesh) # Let cartopy handle the axis limits and artist size @@ -2004,7 +2005,7 @@ def test_plot_rgba_image_transposed(self) -> None: easy_array((4, 10, 15), start=0), dims=["band", "y", "x"] ).plot.imshow() - def test_warns_ambigious_dim(self) -> None: + def test_warns_ambiguous_dim(self) -> None: arr = DataArray(easy_array((3, 3, 3)), dims=["y", "x", "band"]) with pytest.warns(UserWarning): arr.plot.imshow() @@ -2208,7 +2209,7 @@ def test_no_args(self) -> None: def test_names_appear_somewhere(self) -> None: self.darray.name = "testvar" self.g.map_dataarray(xplt.contourf, "x", "y") - for k, ax in zip("abc", self.g.axs.flat): + for k, ax in zip("abc", self.g.axs.flat, strict=True): assert f"z = {k}" == ax.get_title() alltxt = text_in_fig() @@ -2450,11 +2451,15 @@ def test_title_kwargs(self) -> None: g.set_titles(template="{value}", weight="bold") # Rightmost column titles should be bold - for label, ax in zip(self.darray.coords["row"].values, g.axs[:, -1]): + for label, ax in zip( + self.darray.coords["row"].values, g.axs[:, -1], strict=True + ): assert property_in_axes_text("weight", "bold", label, ax) # Top row titles should be bold - for label, ax in zip(self.darray.coords["col"].values, g.axs[0, :]): + for label, ax in zip( + self.darray.coords["col"].values, g.axs[0, :], strict=True + ): assert property_in_axes_text("weight", "bold", label, ax) @pytest.mark.slow @@ -2465,21 +2470,29 @@ def test_default_labels(self) -> None: g.map_dataarray(xplt.imshow, "x", "y") # Rightmost column should be labeled - for label, ax in zip(self.darray.coords["row"].values, g.axs[:, -1]): + for label, ax in zip( + self.darray.coords["row"].values, g.axs[:, -1], strict=True + ): assert substring_in_axes(label, ax) # Top row should be labeled - for label, ax in zip(self.darray.coords["col"].values, g.axs[0, :]): + for label, ax in zip( + self.darray.coords["col"].values, g.axs[0, :], strict=True + ): assert substring_in_axes(label, ax) # ensure that row & col labels can be changed g.set_titles("abc={value}") - for label, ax in zip(self.darray.coords["row"].values, g.axs[:, -1]): + for label, ax in zip( + self.darray.coords["row"].values, g.axs[:, -1], strict=True + ): assert substring_in_axes(f"abc={label}", ax) # previous labels were "row=row0" etc. assert substring_not_in_axes("row=", ax) - for label, ax in zip(self.darray.coords["col"].values, g.axs[0, :]): + for label, ax in zip( + self.darray.coords["col"].values, g.axs[0, :], strict=True + ): assert substring_in_axes(f"abc={label}", ax) # previous labels were "col=row0" etc. 
assert substring_not_in_axes("col=", ax) @@ -2516,10 +2529,10 @@ def setUp(self) -> None: self.darray.row.attrs["units"] = "rowunits" def test_facetgrid_shape(self) -> None: - g = self.darray.plot(row="row", col="col", hue="hue") + g = self.darray.plot(row="row", col="col", hue="hue") # type: ignore[call-arg] assert g.axs.shape == (len(self.darray.row), len(self.darray.col)) - g = self.darray.plot(row="col", col="row", hue="hue") + g = self.darray.plot(row="col", col="row", hue="hue") # type: ignore[call-arg] assert g.axs.shape == (len(self.darray.col), len(self.darray.row)) def test_unnamed_args(self) -> None: @@ -2532,13 +2545,17 @@ def test_unnamed_args(self) -> None: assert lines[0].get_linestyle() == "--" def test_default_labels(self) -> None: - g = self.darray.plot(row="row", col="col", hue="hue") + g = self.darray.plot(row="row", col="col", hue="hue") # type: ignore[call-arg] # Rightmost column should be labeled - for label, ax in zip(self.darray.coords["row"].values, g.axs[:, -1]): + for label, ax in zip( + self.darray.coords["row"].values, g.axs[:, -1], strict=True + ): assert substring_in_axes(label, ax) # Top row should be labeled - for label, ax in zip(self.darray.coords["col"].values, g.axs[0, :]): + for label, ax in zip( + self.darray.coords["col"].values, g.axs[0, :], strict=True + ): assert substring_in_axes(str(label), ax) # Leftmost column should have array name @@ -2547,7 +2564,7 @@ def test_default_labels(self) -> None: def test_test_empty_cell(self) -> None: g = ( - self.darray.isel(row=1) + self.darray.isel(row=1) # type: ignore[call-arg] .drop_vars("row") .plot(col="col", hue="hue", col_wrap=2) ) @@ -2556,7 +2573,7 @@ def test_test_empty_cell(self) -> None: assert not bottomright.get_visible() def test_set_axis_labels(self) -> None: - g = self.darray.plot(row="row", col="col", hue="hue") + g = self.darray.plot(row="row", col="col", hue="hue") # type: ignore[call-arg] g.set_axis_labels("longitude", "latitude") alltxt = text_in_fig() @@ -2573,7 +2590,7 @@ def test_figsize_and_size(self) -> None: def test_wrong_num_of_dimensions(self) -> None: with pytest.raises(ValueError): - self.darray.plot(row="row", hue="hue") + self.darray.plot(row="row", hue="hue") # type: ignore[call-arg] self.darray.plot.line(row="row", hue="hue") @@ -2784,7 +2801,7 @@ def test_default_labels(self) -> None: g = self.ds.plot.scatter(x="A", y="B", row="row", col="col", hue="hue") # Top row should be labeled - for label, ax in zip(self.ds.coords["col"].values, g.axs[0, :]): + for label, ax in zip(self.ds.coords["col"].values, g.axs[0, :], strict=True): assert substring_in_axes(str(label), ax) # Bottom row should have name of x array name and units @@ -3379,14 +3396,14 @@ def test_plot1d_default_rcparams() -> None: fig, ax = plt.subplots(1, 1) ds.plot.scatter(x="A", y="B", marker="o", ax=ax) actual: np.ndarray = mpl.colors.to_rgba_array("w") - expected: np.ndarray = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error? + expected: np.ndarray = ax.collections[0].get_edgecolor() # type: ignore[assignment] np.testing.assert_allclose(actual, expected) # Facetgrids should have the default value as well: fg = ds.plot.scatter(x="A", y="B", col="x", marker="o") ax = fg.axs.ravel()[0] actual = mpl.colors.to_rgba_array("w") - expected = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error? 
+ expected = ax.collections[0].get_edgecolor() # type: ignore[assignment]
np.testing.assert_allclose(actual, expected)
# scatter should not emit any warnings when using unfilled markers:
@@ -3398,7 +3415,7 @@ def test_plot1d_default_rcparams() -> None:
fig, ax = plt.subplots(1, 1)
ds.plot.scatter(x="A", y="B", marker="o", ax=ax, edgecolor="k")
actual = mpl.colors.to_rgba_array("k")
- expected = ax.collections[0].get_edgecolor() # type: ignore[assignment] # mpl error?
+ expected = ax.collections[0].get_edgecolor() # type: ignore[assignment]
np.testing.assert_allclose(actual, expected)
@@ -3422,7 +3439,7 @@ def test_9155() -> None:
with figure_context():
data = xr.DataArray([1, 2, 3], dims=["x"])
fig, ax = plt.subplots(ncols=1, nrows=1)
- data.plot(ax=ax)
+ data.plot(ax=ax) # type: ignore[call-arg]
@requires_matplotlib
diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py
index a0d742c3159..2b321ba8dfc 100644
--- a/xarray/tests/test_plugins.py
+++ b/xarray/tests/test_plugins.py
@@ -21,22 +21,22 @@
class DummyBackendEntrypointArgs(common.BackendEntrypoint):
- def open_dataset(filename_or_obj, *args):
+ def open_dataset(filename_or_obj, *args): # type: ignore[override]
pass
class DummyBackendEntrypointKwargs(common.BackendEntrypoint):
- def open_dataset(filename_or_obj, **kwargs):
+ def open_dataset(filename_or_obj, **kwargs): # type: ignore[override]
pass
class DummyBackendEntrypoint1(common.BackendEntrypoint):
- def open_dataset(self, filename_or_obj, *, decoder):
+ def open_dataset(self, filename_or_obj, *, decoder): # type: ignore[override]
pass
class DummyBackendEntrypoint2(common.BackendEntrypoint):
- def open_dataset(self, filename_or_obj, *, decoder):
+ def open_dataset(self, filename_or_obj, *, decoder): # type: ignore[override]
pass
diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py
index 79869e63ae7..9d880969a82 100644
--- a/xarray/tests/test_rolling.py
+++ b/xarray/tests/test_rolling.py
@@ -312,7 +312,7 @@ def test_rolling_count_correct(self, compute_backend) -> None:
DataArray([np.nan, np.nan, 2, 3, 3, 4, 5, 5, 5, 5, 5], dims="time"),
]
- for kwarg, expected in zip(kwargs, expecteds):
+ for kwarg, expected in zip(kwargs, expecteds, strict=True):
result = da.rolling(**kwarg).count()
assert_equal(result, expected)
diff --git a/xarray/tests/test_strategies.py b/xarray/tests/test_strategies.py
index 79ae4769005..9fdf46b0d85 100644
--- a/xarray/tests/test_strategies.py
+++ b/xarray/tests/test_strategies.py
@@ -138,7 +138,9 @@ def fixed_array_strategy_fn(*, shape=None, dtype=None):
return st.just(arr)
dim_names = data.draw(dimension_names(min_dims=arr.ndim, max_dims=arr.ndim))
- dim_sizes = {name: size for name, size in zip(dim_names, arr.shape)}
+ dim_sizes = {
+ name: size for name, size in zip(dim_names, arr.shape, strict=True)
+ }
var = data.draw(
variables(
diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py
index a7de2e39af6..d9d581cc314 100644
--- a/xarray/tests/test_treenode.py
+++ b/xarray/tests/test_treenode.py
@@ -55,6 +55,15 @@ def test_parent_swap(self):
assert steve.children["Mary"] is mary
assert "Mary" not in john.children
+ def test_forbid_setting_parent_directly(self):
+ john: TreeNode = TreeNode()
+ mary: TreeNode = TreeNode()
+
+ with pytest.raises(
+ AttributeError, match="Cannot set parent attribute directly"
+ ):
+ mary.parent = john
+
def test_multi_child_family(self):
mary: TreeNode = TreeNode()
kate: TreeNode = TreeNode()
@@ -295,7 +304,7 @@ def test_ancestors(self):
_, leaf_f = create_test_tree()
ancestors = leaf_f.ancestors
expected = ["a", "b", "e", "f"]
- for node, expected_name in zip(ancestors, expected):
+ for node, expected_name in zip(ancestors, expected, strict=True):
assert node.name == expected_name
def test_subtree(self):
@@ -312,7 +321,7 @@ def test_subtree(self):
"g",
"i",
]
- for node, expected_name in zip(subtree, expected):
+ for node, expected_name in zip(subtree, expected, strict=True):
assert node.name == expected_name
def test_descendants(self):
@@ -328,7 +337,7 @@ def test_descendants(self):
"g",
"i",
]
- for node, expected_name in zip(descendants, expected):
+ for node, expected_name in zip(descendants, expected, strict=True):
assert node.name == expected_name
def test_leaves(self):
@@ -340,7 +349,7 @@ def test_leaves(self):
"g",
"i",
]
- for node, expected_name in zip(leaves, expected):
+ for node, expected_name in zip(leaves, expected, strict=True):
assert node.name == expected_name
def test_levels(self):
@@ -378,7 +387,7 @@ def test_render_nodetree(self):
john_nodes = john_repr.splitlines()
assert len(john_nodes) == len(expected_nodes)
- for expected_node, repr_node in zip(expected_nodes, john_nodes):
+ for expected_node, repr_node in zip(expected_nodes, john_nodes, strict=True):
assert expected_node == repr_node
diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py
index 232a35d8ea0..ced569ffeab 100644
--- a/xarray/tests/test_units.py
+++ b/xarray/tests/test_units.py
@@ -1997,7 +1997,7 @@ def test_masking(self, func, unit, error, dtype):
def test_squeeze(self, dim, dtype):
shape = (2, 1, 3, 1, 1, 2)
names = list("abcdef")
- dim_lengths = dict(zip(names, shape))
+ dim_lengths = dict(zip(names, shape, strict=True))
array = np.ones(shape=shape) * unit_registry.m
variable = xr.Variable(names, array)
@@ -3429,7 +3429,7 @@ def test_drop_sel(self, raw_values, unit, error, dtype):
)
def test_squeeze(self, shape, dim, dtype):
names = "xyzt"
- dim_lengths = dict(zip(names, shape))
+ dim_lengths = dict(zip(names, shape, strict=False))
names = "xyzt"
array = np.arange(10 * 20).astype(dtype).reshape(shape) * unit_registry.J
data_array = xr.DataArray(data=array, dims=tuple(names[: len(shape)]))
@@ -3659,7 +3659,10 @@ def test_to_unstacked_dataset(self, dtype):
expected = attach_units(
func(strip_units(data_array)),
- {"y": y.units, **dict(zip(x.magnitude, [array.units] * len(y)))},
+ {
+ "y": y.units,
+ **dict(zip(x.magnitude, [array.units] * len(y), strict=True)),
+ },
).rename({elem.magnitude: elem for elem in x})
actual = func(data_array)
@@ -5072,7 +5075,7 @@ def test_head_tail_thin(self, func, variant, dtype):
)
def test_squeeze(self, shape, dim, dtype):
names = "xyzt"
- dim_lengths = dict(zip(names, shape))
+ dim_lengths = dict(zip(names, shape, strict=False))
array1 = (
np.linspace(0, 1, 10 * 20).astype(dtype).reshape(shape) * unit_registry.degK
)
diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py
index ecec3ca507b..86e34d151a8 100644
--- a/xarray/tests/test_utils.py
+++ b/xarray/tests/test_utils.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from collections.abc import Hashable
+from types import EllipsisType
import numpy as np
import pandas as pd
@@ -288,7 +289,7 @@ def test_parse_dims_set() -> None:
@pytest.mark.parametrize(
"dim", [pytest.param(None, id="None"), pytest.param(..., id="ellipsis")]
)
-def test_parse_dims_replace_none(dim: None | ellipsis) -> None:
+def test_parse_dims_replace_none(dim: None | EllipsisType) -> None:
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
actual = utils.parse_dims(dim, all_dims, replace_none=True)
assert actual == all_dims
diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index a1d8994a736..9e2e12fc045 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -576,7 +576,7 @@ def test_copy_deep_recursive(self) -> None:
# lets just ensure that deep copy works without RecursionError
v.copy(deep=True)
- # indirect recusrion
+ # indirect recursion
v2 = self.cls("y", [2, 3])
v.attrs["other"] = v2
v2.attrs["other"] = v
@@ -654,7 +654,7 @@ def test_aggregate_complex(self):
expected = Variable((), 0.5 + 1j)
assert_allclose(v.mean(), expected)
- def test_pandas_cateogrical_dtype(self):
+ def test_pandas_categorical_dtype(self):
data = pd.Categorical(np.arange(10, dtype="int64"))
v = self.cls("x", data)
print(v) # should not error
@@ -1017,7 +1017,7 @@ def test_nd_rolling(self, center, dims):
fill_value=np.nan,
)
expected = x
- for dim, win, cent in zip(dims, window, center):
+ for dim, win, cent in zip(dims, window, center, strict=True):
expected = expected.rolling_window(
dim=dim,
window=win,
@@ -1575,13 +1575,13 @@ def test_transpose_0d(self):
actual = variable.transpose()
assert_identical(actual, variable)
- def test_pandas_cateogrical_dtype(self):
+ def test_pandas_categorical_dtype(self):
data = pd.Categorical(np.arange(10, dtype="int64"))
v = self.cls("x", data)
print(v) # should not error
assert pd.api.types.is_extension_array_dtype(v.dtype)
- def test_pandas_cateogrical_no_chunk(self):
+ def test_pandas_categorical_no_chunk(self):
data = pd.Categorical(np.arange(10, dtype="int64"))
v = self.cls("x", data)
with pytest.raises(
@@ -1806,7 +1806,8 @@ def raise_if_called(*args, **kwargs):
@pytest.mark.parametrize("skipna", [True, False, None])
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
@pytest.mark.parametrize(
- "axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]])
+ "axis, dim",
+ zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]], strict=True),
)
def test_quantile(self, q, axis, dim, skipna):
d = self.d.copy()
@@ -2386,7 +2387,7 @@ def test_multiindex(self):
def test_pad(self, mode, xr_arg, np_arg):
super().test_pad(mode, xr_arg, np_arg)
- def test_pandas_cateogrical_dtype(self):
+ def test_pandas_categorical_dtype(self):
data = pd.Categorical(np.arange(10, dtype="int64"))
with pytest.raises(ValueError, match="was found to be a Pandas ExtensionArray"):
self.cls("x", data)
diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py
index f3337d70a76..93a200c07a6 100644
--- a/xarray/tests/test_weighted.py
+++ b/xarray/tests/test_weighted.py
@@ -24,7 +24,7 @@ def test_weighted_non_DataArray_weights(as_dataset: bool) -> None:
data = data.to_dataset(name="data")
with pytest.raises(ValueError, match=r"`weights` must be a DataArray"):
- data.weighted([1, 2]) # type: ignore
+ data.weighted([1, 2]) # type: ignore[arg-type]
@pytest.mark.parametrize("as_dataset", (True, False))
diff --git a/xarray/util/deprecation_helpers.py b/xarray/util/deprecation_helpers.py
index c8e594508d9..0cdee2bd564 100644
--- a/xarray/util/deprecation_helpers.py
+++ b/xarray/util/deprecation_helpers.py
@@ -108,7 +108,9 @@ def inner(*args, **kwargs):
stacklevel=2,
)
- zip_args = zip(kwonly_args[:n_extra_args], args[-n_extra_args:])
+ zip_args = zip(
+ kwonly_args[:n_extra_args], args[-n_extra_args:], strict=True
+ )
kwargs.update({name: arg for name, arg in zip_args})
return func(*args[:-n_extra_args], **kwargs)
@@ -142,4 +144,4 @@ def wrapper(*args, **kwargs):
# We're quite confident we're just returning `T` from this function, so it's fine to ignore typing
# within the function.
- return wrapper # type: ignore
+ return wrapper # type: ignore[return-value]
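A note on the recurring `strict=True` additions in the hunks above: `zip(..., strict=True)` is standard Python 3.10+ and turns silent truncation on a length mismatch into a `ValueError`, which is the behaviour these tests now rely on. A minimal, self-contained sketch (the lists are illustrative and not taken from the test suite):

    # zip() silently stops at the shorter iterable unless strict=True is passed.
    names = ["a", "b", "e", "f"]
    nodes = ["a", "b", "e"]  # deliberately one element short

    pairs = list(zip(nodes, names))  # 3 pairs, the mismatch goes unnoticed
    try:
        pairs = list(zip(nodes, names, strict=True))
    except ValueError as err:
        print(err)  # "zip() argument 2 is longer than argument 1"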
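Similarly, the test_utils.py hunk swaps the lowercase `ellipsis` alias for `types.EllipsisType`, the public name for `type(...)` available since Python 3.10. A small sketch of annotating a parameter that may be `...` (the helper name and its semantics are illustrative only, not xarray API):

    from __future__ import annotations

    from types import EllipsisType


    def expand_dims_arg(dim: str | EllipsisType | None) -> tuple[str, ...]:
        # In this illustrative helper, both None and ... mean "all dimensions".
        all_dims = ("x", "y", "time")
        if dim is None or dim is ...:
            return all_dims
        return (dim,)


    print(expand_dims_arg(...))   # ('x', 'y', 'time')
    print(expand_dims_arg("x"))   # ('x',)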
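Finally, several hunks narrow bare `# type: ignore` comments to code-specific ones such as `# type: ignore[arg-type]` or `# type: ignore[return-value]`. A scoped ignore suppresses only the named error, so new, unrelated problems on the same line still get reported; mypy can also enforce this style project-wide through its `ignore-without-code` error code. A tiny sketch of the difference (the assignments are illustrative and harmless at runtime since annotations are not enforced):

    # Old style: a bare ignore silences every mypy error on the line.
    a: int = "one"  # type: ignore

    # New style: only [assignment] errors are silenced; anything else on the
    # line would still be flagged.
    b: int = "two"  # type: ignore[assignment]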