diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 165d7fd3..1d7a79b9 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -18,11 +18,11 @@ jobs:
         strategy:
             fail-fast: false
             matrix:
-                python: ["3.9", "3.12"]
+                python: ["3.10", "3.12"]
                 os: [ubuntu-latest]
                 include:
                     - os: macos-latest
-                      python: "3.9"
+                      python: "3.10"
                     - os: macos-latest
                       python: "3.12"
                       pip-flags: "--pre"
diff --git a/.mypy.ini b/.mypy.ini
index a47d0f87..77bf7465 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -1,5 +1,5 @@
 [mypy]
-python_version = 3.9
+python_version = 3.10
 plugins = numpy.typing.mypy_plugin

 ignore_errors = False
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6c3d0777..25a5f04d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,14 +2,14 @@ fail_fast: false
 default_language_version:
     python: python3
 default_stages:
-    - commit
-    - push
+    - pre-commit
+    - pre-push
 minimum_pre_commit_version: 2.16.0
 ci:
     skip: []
 repos:
     - repo: https://github.com/psf/black
-      rev: 24.8.0
+      rev: 24.10.0
       hooks:
           - id: black
     - repo: https://github.com/pre-commit/mirrors-prettier
@@ -17,17 +17,17 @@ repos:
       hooks:
           - id: prettier
     - repo: https://github.com/asottile/blacken-docs
-      rev: 1.18.0
+      rev: 1.19.1
      hooks:
          - id: blacken-docs
    - repo: https://github.com/pre-commit/mirrors-mypy
-      rev: v1.11.2
+      rev: v1.13.0
      hooks:
          - id: mypy
            additional_dependencies: [numpy, types-requests]
            exclude: tests/|docs/
    - repo: https://github.com/astral-sh/ruff-pre-commit
-      rev: v0.6.8
+      rev: v0.7.3
      hooks:
          - id: ruff
            args: [--fix, --exit-non-zero-on-fix]
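Note: the CI matrix, the mypy target, and (further below) the black/ruff targets all move from 3.9 to 3.10 together. A minimal sketch, standard library only, of the 3.10-only construct this changeset adopts most widely: PEP 604 unions, which isinstance() accepts at runtime.

    # Sketch: `X | Y` builds a types.UnionType; from Python 3.10 onward it is
    # valid both in annotations and as the second argument of isinstance().
    from pathlib import Path

    def open_store(store: str | Path) -> str:
        assert isinstance(store, str | Path)  # runtime union check, no typing import
        return f"opening {store}"

    print(open_store(Path("data.zarr")))  # opening data.zarr
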
diff --git a/CHANGELOG.md b/CHANGELOG.md
index c5519278..7a93b602 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,12 +8,28 @@ and this project adheres to [Semantic Versioning][].
 [keep a changelog]: https://keepachangelog.com/en/1.0.0/
 [semantic versioning]: https://semver.org/spec/v2.0.0.html

-## [0.2.4] - xxxx-xx-xx
+## [0.2.6] - TBD
+
+### Fixed
+
+- Updated deprecated default stages of `pre-commit` #771
+
+## [0.2.5] - 2024-06-11
+
+### Fixed
+
+- Incompatibility issues due to newest release of `multiscale-spatial-image` #760
+
+## [0.2.4] - 2024-06-11
+
+### Major
+
+- Enable vectorization of `bounding_box_query` for all `SpatialData` elements. #699

 ### Minor

-- Added `shortest_path` parameter to `get_transformation_between_coordinate_systems`
-- Added `get_pyramid_levels()` utils API
+- Added `shortest_path` parameter to `get_transformation_between_coordinate_systems` #714
+- Added `get_pyramid_levels()` utils API #719
 - Improved ergonomics of `concatenate()` when element names are non-unique #720
 - Improved performance of writing images with multiscales #577
diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks
index 9281cb12..45c4d0ed 160000
--- a/docs/tutorials/notebooks
+++ b/docs/tutorials/notebooks
@@ -1 +1 @@
-Subproject commit 9281cb127ccd849a14daf83fe79c7cb21613c662
+Subproject commit 45c4d0edd826dcf472725991ad688f80f1d1dd5a
diff --git a/pyproject.toml b/pyproject.toml
index 683c1ed7..4433b067 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ maintainers = [
 urls.Documentation = "https://spatialdata.scverse.org/en/latest"
 urls.Source = "https://github.com/scverse/spatialdata.git"
 urls.Home-page = "https://github.com/scverse/spatialdata.git"
-requires-python = ">=3.9"
+requires-python = ">=3.10, <3.13" # include 3.13 once multiscale-spatial-image conflicts are resolved
 dynamic= [
   "version" # allow version to be set by git tags
 ]
@@ -28,8 +28,9 @@ dependencies = [
     "dask>=2024.4.1",
     "fsspec<=2023.6",
     "geopandas>=0.14",
-    "multiscale_spatial_image>=1.0.0",
+    "multiscale_spatial_image>=2.0.1",
     "networkx",
+    "numba",
     "numpy",
     "ome_zarr>=0.8.4",
     "pandas",
@@ -42,8 +43,7 @@ dependencies = [
     "scikit-image",
     "scipy",
     "typing_extensions>=4.8.0",
-    "xarray",
-    "xarray-datatree",
+    "xarray>=2024.10.0",
     "xarray-schema",
     "xarray-spatial>=0.3.5",
     "zarr",
@@ -69,6 +69,7 @@ test = [
     "pytest",
     "pytest-cov",
     "pytest-mock",
+    "torch",
 ]
 torch = [
     "torch"
@@ -100,7 +101,7 @@ filterwarnings = [

 [tool.black]
 line-length = 120
-target-version = ['py39']
+target-version = ['py310']
 include = '\.pyi?$'
 exclude = '''
 (
@@ -145,7 +146,7 @@ exclude = [
     "setup.py",
 ]
 line-length = 120
-target-version = "py39"
+target-version = "py310"

 [tool.ruff.lint]
 ignore = [
diff --git a/src/spatialdata/_core/_deepcopy.py b/src/spatialdata/_core/_deepcopy.py
index 819d1e2a..1ce04409 100644
--- a/src/spatialdata/_core/_deepcopy.py
+++ b/src/spatialdata/_core/_deepcopy.py
@@ -7,9 +7,8 @@
 from dask.array.core import Array as DaskArray
 from dask.array.core import from_array
 from dask.dataframe import DataFrame as DaskDataFrame
-from datatree import DataTree
 from geopandas import GeoDataFrame
-from xarray import DataArray
+from xarray import DataArray, DataTree

 from spatialdata._core.spatialdata import SpatialData
 from spatialdata.models._utils import SpatialElement
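Note: the dependency changes above (`xarray>=2024.10.0`, dropping `xarray-datatree`) pair with the import rewrite repeated through the rest of this diff: `from datatree import DataTree` becomes `from xarray import DataTree`, since DataTree now ships with xarray itself. A sketch of the new import surface:

    # Sketch: DataTree is part of xarray as of the 2024.10 release, so the
    # standalone `datatree` package (and the xarray-datatree pin) can go.
    import numpy as np
    from xarray import DataArray, DataTree

    image = DataArray(np.zeros((1, 4, 4)), dims=("c", "y", "x"), name="image")
    print(isinstance(image, DataArray | DataTree))  # True
    print(DataTree(name="root").name)               # root
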
diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py
index c8b07fdc..24ea616c 100644
--- a/src/spatialdata/_core/centroids.py
+++ b/src/spatialdata/_core/centroids.py
@@ -7,10 +7,9 @@
 import pandas as pd
 import xarray as xr
 from dask.dataframe import DataFrame as DaskDataFrame
-from datatree import DataTree
 from geopandas import GeoDataFrame
 from shapely import MultiPolygon, Point, Polygon
-from xarray import DataArray
+from xarray import DataArray, DataTree

 from spatialdata._core.operations.transform import transform
 from spatialdata.models import get_axes_names
@@ -86,7 +85,7 @@ def _get_centroids_for_axis(xdata: xr.DataArray, axis: str) -> pd.DataFrame:
             centroids[label_value] += count * i.values.item()

     all_labels_values, all_labels_counts = da.unique(xdata.data, return_counts=True)
-    all_labels = dict(zip(all_labels_values.compute(), all_labels_counts.compute()))
+    all_labels = dict(zip(all_labels_values.compute(), all_labels_counts.compute(), strict=True))
     for label_value in centroids:
         centroids[label_value] /= all_labels[label_value]
     centroids = dict(sorted(centroids.items(), key=lambda x: x[0]))
@@ -132,7 +131,7 @@ def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame:
     if isinstance(first_geometry, Point):
         xy = e.geometry.get_coordinates().values
     else:
-        assert isinstance(first_geometry, (Polygon, MultiPolygon)), (
+        assert isinstance(first_geometry, Polygon | MultiPolygon), (
             f"Expected a GeoDataFrame either composed entirely of circles (Points with the `radius` column) or"
             f" Polygons/MultiPolygons. Found {type(first_geometry)} instead."
         )
diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py
index 6be34850..6b91e22d 100644
--- a/src/spatialdata/_core/concatenate.py
+++ b/src/spatialdata/_core/concatenate.py
@@ -52,7 +52,7 @@ def _concatenate_tables(
         raise ValueError("`instance_key` must be specified if tables have different instance keys")

     tables_l = []
-    for table_region_key, table_instance_key, table in zip(region_keys, instance_keys, tables):
+    for table_region_key, table_instance_key, table in zip(region_keys, instance_keys, tables, strict=True):
         rename_dict = {}
         if table_region_key != region_key:
             rename_dict[table_region_key] = region_key
@@ -247,7 +247,7 @@ def _fix_ensure_unique_element_names(
                 tables[new_name] = table
         tables_by_sdata.append(tables)
     sdatas_fixed = []
-    for elements, tables in zip(elements_by_sdata, tables_by_sdata):
+    for elements, tables in zip(elements_by_sdata, tables_by_sdata, strict=True):
         sdata = SpatialData.init_from_elements(elements, tables=tables)
         sdatas_fixed.append(sdata)
     return sdatas_fixed
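Note: `strict=True`, added to `zip()` calls throughout this diff, is new in Python 3.10 (PEP 618). It turns silent truncation into an error when the iterables disagree in length, which is exactly the invariant these pairings rely on. A sketch:

    # Sketch: zip(..., strict=True) raises instead of silently dropping the
    # unmatched tail when the iterables have different lengths.
    region_keys = ["cells", "nuclei"]
    tables = ["table_a", "table_b", "table_c"]

    try:
        list(zip(region_keys, tables, strict=True))
    except ValueError as err:
        print(err)  # zip() argument 2 is longer than argument 1
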
""" - assert isinstance(shapes.geometry.iloc[0], (Polygon, MultiPolygon)) + assert isinstance(shapes.geometry.iloc[0], Polygon | MultiPolygon) axes = get_axes_names(shapes) bounds = shapes["geometry"].bounds return {ax: (bounds[f"min{ax}"].min(), bounds[f"max{ax}"].max()) for ax in axes} @@ -201,7 +200,7 @@ def _( new_max_coordinates_dict: dict[str, list[float]] = defaultdict(list) mask = [has_images, has_labels, has_points, has_shapes] include_spatial_elements = ["images", "labels", "points", "shapes"] - include_spatial_elements = [i for (i, v) in zip(include_spatial_elements, mask) if v] + include_spatial_elements = [i for (i, v) in zip(include_spatial_elements, mask, strict=True) if v] if elements is None: # to shut up ruff elements = [] @@ -217,7 +216,7 @@ def _( assert isinstance(transformations, dict) coordinate_systems = list(transformations.keys()) if coordinate_system in coordinate_systems: - if isinstance(element_obj, (DaskDataFrame, GeoDataFrame)): + if isinstance(element_obj, DaskDataFrame | GeoDataFrame): extent = get_extent(element_obj, coordinate_system=coordinate_system, exact=exact) else: extent = get_extent(element_obj, coordinate_system=coordinate_system) @@ -254,7 +253,7 @@ def _get_extent_of_shapes(e: GeoDataFrame) -> BoundingBoxDescription: first_geometry = e_temp["geometry"].iloc[0] if isinstance(first_geometry, Point): return _get_extent_of_circles(e) - assert isinstance(first_geometry, (Polygon, MultiPolygon)) + assert isinstance(first_geometry, Polygon | MultiPolygon) return _get_extent_of_polygons_multipolygons(e) diff --git a/src/spatialdata/_core/operations/_utils.py b/src/spatialdata/_core/operations/_utils.py index 0cb0bf10..e78fccd3 100644 --- a/src/spatialdata/_core/operations/_utils.py +++ b/src/spatialdata/_core/operations/_utils.py @@ -2,8 +2,7 @@ from typing import TYPE_CHECKING -from datatree import DataTree -from xarray import DataArray +from xarray import DataArray, DataTree from spatialdata.models import SpatialElement @@ -115,7 +114,7 @@ def transform_to_data_extent( } for _, element_name, element in sdata_raster.gen_spatial_elements(): - if isinstance(element, (DataArray, DataTree)): + if isinstance(element, DataArray | DataTree): rasterized = rasterize( element, axes=data_extent_axes, diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index ff029f7f..dde0338a 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -10,11 +10,10 @@ import numpy as np import pandas as pd from dask.dataframe import DataFrame as DaskDataFrame -from datatree import DataTree from geopandas import GeoDataFrame from scipy import sparse from shapely import Point -from xarray import DataArray +from xarray import DataArray, DataTree from xrspatial import zonal_stats from spatialdata._core.operations._utils import _parse_element @@ -246,7 +245,7 @@ def _create_sdata_from_table_and_shapes( table = TableModel.parse(table, region=shapes_name, region_key=region_key, instance_key=instance_key) # labels case, needs conversion from str to int - if isinstance(shapes, (DataArray, DataTree)): + if isinstance(shapes, DataArray | DataTree): table.obs[instance_key] = table.obs[instance_key].astype(int) if deepcopy: diff --git a/src/spatialdata/_core/operations/map.py b/src/spatialdata/_core/operations/map.py index b3810352..16bd696f 100644 --- a/src/spatialdata/_core/operations/map.py +++ b/src/spatialdata/_core/operations/map.py @@ -1,13 +1,12 @@ from __future__ import 
diff --git a/src/spatialdata/_core/operations/map.py b/src/spatialdata/_core/operations/map.py
index b3810352..16bd696f 100644
--- a/src/spatialdata/_core/operations/map.py
+++ b/src/spatialdata/_core/operations/map.py
@@ -1,13 +1,12 @@
 from __future__ import annotations

-from collections.abc import Iterable, Mapping
+from collections.abc import Callable, Iterable, Mapping
 from types import MappingProxyType
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any

 import dask.array as da
 from dask.array.overlap import coerce_depth
-from datatree import DataTree
-from xarray import DataArray
+from xarray import DataArray, DataTree

 from spatialdata.models._utils import get_axes_names, get_channels, get_raster_model_from_data_dims
 from spatialdata.transformations import get_transformation
diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py
index e7172bf1..b7bdb929 100644
--- a/src/spatialdata/_core/operations/rasterize.py
+++ b/src/spatialdata/_core/operations/rasterize.py
@@ -5,10 +5,9 @@
 import numpy as np
 from dask.array import Array as DaskArray
 from dask.dataframe import DataFrame as DaskDataFrame
-from datatree import DataTree
 from geopandas import GeoDataFrame
 from shapely import Point
-from xarray import DataArray
+from xarray import DataArray, DataTree

 from spatialdata._core.operations._utils import _parse_element
 from spatialdata._core.operations.transform import transform
diff --git a/src/spatialdata/_core/operations/rasterize_bins.py b/src/spatialdata/_core/operations/rasterize_bins.py
index 18d72d04..c8e5e648 100644
--- a/src/spatialdata/_core/operations/rasterize_bins.py
+++ b/src/spatialdata/_core/operations/rasterize_bins.py
@@ -73,7 +73,7 @@ def rasterize_bins(
     """
     element = sdata[bins]
     table = sdata.tables[table_name]
-    if not isinstance(element, (GeoDataFrame, DaskDataFrame)):
+    if not isinstance(element, GeoDataFrame | DaskDataFrame):
         raise ValueError("The bins should be a GeoDataFrame or a DaskDataFrame.")

     _, region_key, instance_key = get_table_keys(table)
@@ -94,7 +94,7 @@
     keys = ([value_key] if isinstance(value_key, str) else value_key) if value_key is not None else table.var_names
     if (value_key is None or any(key in table.var_names for key in keys)) and not isinstance(
-        table.X, (csc_matrix, np.ndarray)
+        table.X, csc_matrix | np.ndarray
     ):
         raise ValueError(
             "To speed up bins rasterization, the X matrix in the table, when sparse, should be a csc_matrix matrix. "
@@ -162,7 +162,7 @@ def channel_rasterization(block_id: tuple[int, int, int] | None) -> ArrayLike:
                 sub_x = sub_df.geometry.x.values
                 sub_y = sub_df.geometry.y.values
             else:
-                assert isinstance(sub_df.iloc[0].geometry, (Polygon, MultiPolygon))
+                assert isinstance(sub_df.iloc[0].geometry, Polygon | MultiPolygon)
                 sub_x = sub_df.centroid.x
                 sub_y = sub_df.centroid.y
         else:
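Note: the csc_matrix requirement enforced above is presumably about access pattern: rasterization reads the table one column (feature) at a time, and the column-compressed layout makes that slice cheap. A sketch:

    # Sketch: CSC stores each column contiguously, so slicing one column
    # (one feature) is cheap; CSR would traverse every row instead.
    import numpy as np
    from scipy.sparse import csc_matrix

    x = csc_matrix(np.arange(12.0).reshape(4, 3))
    column = x[:, 1]  # efficient single-column slice on CSC
    print(column.toarray().ravel())  # [ 1.  4.  7. 10.]
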
" @@ -162,7 +162,7 @@ def channel_rasterization(block_id: tuple[int, int, int] | None) -> ArrayLike: sub_x = sub_df.geometry.x.values sub_y = sub_df.geometry.y.values else: - assert isinstance(sub_df.iloc[0].geometry, (Polygon, MultiPolygon)) + assert isinstance(sub_df.iloc[0].geometry, Polygon | MultiPolygon) sub_x = sub_df.centroid.x sub_y = sub_df.centroid.y else: diff --git a/src/spatialdata/_core/operations/transform.py b/src/spatialdata/_core/operations/transform.py index 3282809c..e1fdca74 100644 --- a/src/spatialdata/_core/operations/transform.py +++ b/src/spatialdata/_core/operations/transform.py @@ -10,10 +10,9 @@ import numpy as np from dask.array.core import Array as DaskArray from dask.dataframe import DataFrame as DaskDataFrame -from datatree import DataTree from geopandas import GeoDataFrame from shapely import Point -from xarray import DataArray +from xarray import DataArray, Dataset, DataTree from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike @@ -161,13 +160,13 @@ def _set_transformation_for_transformed_elements( assert to_coordinate_system is None to_prepend: BaseTransformation | None - if isinstance(element, (DataArray, DataTree)): + if isinstance(element, DataArray | DataTree): if maintain_positioning: assert raster_translation is not None to_prepend = Sequence([raster_translation, transformation.inverse()]) else: to_prepend = raster_translation - elif isinstance(element, (GeoDataFrame, DaskDataFrame)): + elif isinstance(element, GeoDataFrame | DaskDataFrame): assert raster_translation is None to_prepend = transformation.inverse() if maintain_positioning else Identity() else: @@ -393,8 +392,11 @@ def _( raster_translation = raster_translation_single_scale # we set a dummy empty dict for the transformation that will be replaced with the correct transformation for # each scale later in this function, when calling set_transformation() - transformed_dict[k] = DataArray(transformed_dask, dims=xdata.dims, name=xdata.name, attrs={TRANSFORM_KEY: {}}) + transformed_dict[k] = Dataset( + {"image": DataArray(transformed_dask, dims=xdata.dims, name=xdata.name, attrs={TRANSFORM_KEY: {}})} + ) if channel_names is not None: + # This expression returns a dataset now. transformed_dict[k] = transformed_dict[k].assign_coords(c=channel_names) # mypy thinks that schema could be ShapesModel, PointsModel, ... 
diff --git a/src/spatialdata/_core/operations/vectorize.py b/src/spatialdata/_core/operations/vectorize.py
index 766a0d96..d90156fb 100644
--- a/src/spatialdata/_core/operations/vectorize.py
+++ b/src/spatialdata/_core/operations/vectorize.py
@@ -9,11 +9,10 @@
 import shapely
 import skimage.measure
 from dask.dataframe import DataFrame as DaskDataFrame
-from datatree import DataTree
 from geopandas import GeoDataFrame
 from shapely import MultiPolygon, Point, Polygon
 from skimage.measure._regionprops import RegionProperties
-from xarray import DataArray
+from xarray import DataArray, DataTree

 from spatialdata._core.centroids import get_centroids
 from spatialdata._core.operations.aggregate import aggregate
@@ -106,7 +105,7 @@ def _(element: DataArray | DataTree, **kwargs: Any) -> GeoDataFrame:
 @to_circles.register(GeoDataFrame)
 def _(element: GeoDataFrame, **kwargs: Any) -> GeoDataFrame:
     assert len(kwargs) == 0
-    if isinstance(element.geometry.iloc[0], (Polygon, MultiPolygon)):
+    if isinstance(element.geometry.iloc[0], Polygon | MultiPolygon):
         radius = np.sqrt(element.geometry.area / np.pi)
         centroids = _get_centroids(element)
         obs = pd.DataFrame({"radius": radius})
@@ -258,7 +257,7 @@ def _dissolve_on_overlaps(label: int, group: GeoDataFrame) -> GeoDataFrame:
 @to_polygons.register(GeoDataFrame)
 def _(gdf: GeoDataFrame, buffer_resolution: int = 16) -> GeoDataFrame:
-    if isinstance(gdf.geometry.iloc[0], (Polygon, MultiPolygon)):
+    if isinstance(gdf.geometry.iloc[0], Polygon | MultiPolygon):
         return gdf
     if isinstance(gdf.geometry.iloc[0], Point):
         ShapesModel.validate_shapes_not_mixed_types(gdf)
@@ -274,7 +273,7 @@ def _(gdf: GeoDataFrame, buffer_resolution: int = 16) -> GeoDataFrame:
             # TODO replace with a function to copy the metadata (the parser could also do this): https://github.com/scverse/spatialdata/issues/258
             buffered_df.attrs[ShapesModel.TRANSFORM_KEY] = gdf.attrs[ShapesModel.TRANSFORM_KEY]
             return buffered_df
-        assert isinstance(gdf.geometry.iloc[0], (Polygon, MultiPolygon))
+        assert isinstance(gdf.geometry.iloc[0], Polygon | MultiPolygon)
         return gdf
     raise RuntimeError("Unsupported geometry type: " f"{type(gdf.geometry.iloc[0])}")
diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py
index 1f39212c..0229d8bd 100644
--- a/src/spatialdata/_core/query/_utils.py
+++ b/src/spatialdata/_core/query/_utils.py
@@ -5,8 +5,7 @@
 import numba as nb
 import numpy as np
 from anndata import AnnData
-from datatree import DataTree
-from xarray import DataArray
+from xarray import DataArray, Dataset, DataTree

 from spatialdata._core._elements import Tables
 from spatialdata._core.spatialdata import SpatialData
@@ -96,7 +95,7 @@ def get_bounding_box_corners(
     return output.squeeze().drop_vars("box")


-@nb.njit(parallel=False, nopython=True)
+@nb.jit(parallel=False, nopython=True)
 def _create_slices_and_translation(
     min_values: nb.types.Array,
     max_values: nb.types.Array,
@@ -138,7 +137,7 @@ def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None:
     if len(scales_to_keep) == 0:
         return None

-    d = {k: d[k] for k in scales_to_keep}
+    d = {k: Dataset({"image": d[k]}) for k in scales_to_keep}
     result = DataTree.from_dict(d)

     # Rechunk the data to avoid irregular chunks
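Note: the decorator change from `nb.njit(parallel=False, nopython=True)` to `nb.jit(parallel=False, nopython=True)` in query/_utils.py is behavior-preserving: `njit` already implies `nopython=True`, and recent numba releases warn about the redundant keyword, while on `jit` the keyword is meaningful. A sketch with a hypothetical, simplified helper (not the real `_create_slices_and_translation`):

    # Sketch: nb.jit(nopython=True) requests nopython compilation explicitly;
    # clamp_ranges is a hypothetical stand-in for the decorated helper.
    import numba as nb
    import numpy as np

    @nb.jit(nopython=True, parallel=False)
    def clamp_ranges(min_values, max_values, upper):
        out = np.empty((min_values.shape[0], 2))
        for i in range(min_values.shape[0]):
            out[i, 0] = max(min_values[i], 0.0)    # clip lower bound at 0
            out[i, 1] = min(max_values[i], upper)  # clip upper bound at `upper`
        return out

    print(clamp_ranges(np.array([-1.0, 2.0]), np.array([3.0, 9.0]), 5.0))
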
diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py
index 76de711c..c29a22a0 100644
--- a/src/spatialdata/_core/query/relational_query.py
+++ b/src/spatialdata/_core/query/relational_query.py
@@ -13,9 +13,8 @@
 import pandas as pd
 from anndata import AnnData
 from dask.dataframe import DataFrame as DaskDataFrame
-from datatree import DataTree
 from geopandas import GeoDataFrame
-from xarray import DataArray
+from xarray import DataArray, DataTree

 from spatialdata._core.spatialdata import SpatialData
 from spatialdata._types import ArrayLike
@@ -653,7 +652,7 @@ def join_spatialelement_table(
     if sdata is not None:
         elements_dict = _create_sdata_elements_dict_for_join(sdata, spatial_element_names)
     else:
-        derived_sdata = SpatialData.from_elements_dict(dict(zip(spatial_element_names, spatial_elements)))
+        derived_sdata = SpatialData.from_elements_dict(dict(zip(spatial_element_names, spatial_elements, strict=True)))
         element_types = ["labels", "shapes", "points"]
         elements_dict = defaultdict(lambda: defaultdict(dict))
         for element_type in element_types:
@@ -919,7 +918,7 @@ def get_values(
         x = matched_table[:, value_key_values].X
         import scipy

-        if isinstance(x, (scipy.sparse.csr_matrix, scipy.sparse.csc_matrix, scipy.sparse.coo_matrix)):
+        if isinstance(x, scipy.sparse.csr_matrix | scipy.sparse.csc_matrix | scipy.sparse.coo_matrix):
             x = x.todense()
         df = pd.DataFrame(x, columns=value_key_values)
         if origin == "obsm":
diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py
index 32d93f1a..a1868606 100644
--- a/src/spatialdata/_core/query/spatial_query.py
+++ b/src/spatialdata/_core/query/spatial_query.py
@@ -2,19 +2,18 @@

 import warnings
 from abc import abstractmethod
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 from dataclasses import dataclass
 from functools import singledispatch
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any

 import dask.array as da
 import dask.dataframe as dd
 import numpy as np
 from dask.dataframe import DataFrame as DaskDataFrame
-from datatree import DataTree
 from geopandas import GeoDataFrame
 from shapely.geometry import MultiPolygon, Point, Polygon
-from xarray import DataArray
+from xarray import DataArray, DataTree

 from spatialdata import to_polygons
 from spatialdata._core.query._utils import (
@@ -603,7 +602,7 @@ def _(
     if isinstance(query_result, list):
         processed_results = []
-        for result, translation_vector in zip(query_result, translation_vectors):
+        for result, translation_vector in zip(query_result, translation_vectors, strict=True):
             processed_result = _process_query_result(result, translation_vector, axes)
             if processed_result is not None:
                 processed_results.append(processed_result)
@@ -688,7 +687,7 @@ def _(
     # transform the element to the query coordinate system
     output: list[DaskDataFrame | None] = []
-    for p, min_c, max_c in zip(points_in_intrinsic_bounding_box, min_coordinate, max_coordinate):
+    for p, min_c, max_c in zip(points_in_intrinsic_bounding_box, min_coordinate, max_coordinate, strict=True):
         if p is None:
             output.append(None)
         else:
diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index 77bd6df9..c0caf8aa 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -14,12 +14,11 @@
 from dask.dataframe import DataFrame as DaskDataFrame
 from dask.dataframe import read_parquet
 from dask.delayed import Delayed
-from datatree import DataTree
 from geopandas import GeoDataFrame
 from ome_zarr.io import parse_url
 from ome_zarr.types import JSONDict
 from shapely import MultiPolygon, Polygon
-from xarray import DataArray
+from xarray import DataArray, DataTree

 from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables
 from spatialdata._logging import logger
@@ -468,7 +467,7 @@ def set_table_annotates_spatialelement(
         table = self.tables[table_name]
         element_names = {element[1] for element in self._gen_elements()}
         if (isinstance(region, str) and region not in element_names) or (
-            isinstance(region, (list, pd.Series))
+            isinstance(region, list | pd.Series)
             and not all(region_element in element_names for region_element in region)
         ):
             raise ValueError(f"Annotation target '{region}' not present as SpatialElement in SpatialData object.")
@@ -546,7 +545,7 @@ def path(self) -> Path | None:
     @path.setter
     def path(self, value: Path | None) -> None:
-        if value is None or isinstance(value, (str, Path)):
+        if value is None or isinstance(value, str | Path):
             self._path = value
         else:
             raise TypeError("Path must be `None`, a `str` or a `Path` object.")
@@ -1480,13 +1479,13 @@ def write_transformations(self, element_name: str | None = None) -> None:
                 zarr_path=Path(self.path), element_type=element_type, element_name=element_name
             )
             axes = get_axes_names(element)
-            if isinstance(element, (DataArray, DataTree)):
+            if isinstance(element, DataArray | DataTree):
                 from spatialdata._io._utils import (
                     overwrite_coordinate_transformations_raster,
                 )

                 overwrite_coordinate_transformations_raster(group=element_group, axes=axes, transformations=transformations)
-            elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)):
+            elif isinstance(element, DaskDataFrame | GeoDataFrame | AnnData):
                 from spatialdata._io._utils import (
                     overwrite_coordinate_transformations_non_raster,
                 )
diff --git a/src/spatialdata/_docs.py b/src/spatialdata/_docs.py
index d3e9dae7..b8e13bc7 100644
--- a/src/spatialdata/_docs.py
+++ b/src/spatialdata/_docs.py
@@ -1,5 +1,6 @@
 # from https://stackoverflow.com/questions/10307696/how-to-put-a-variable-into-python-docstring
-from typing import Any, Callable, TypeVar
+from collections.abc import Callable
+from typing import Any, TypeVar

 T = TypeVar("T")
diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py
index 18b86f1a..4c2a401e 100644
--- a/src/spatialdata/_io/_utils.py
+++ b/src/spatialdata/_io/_utils.py
@@ -16,9 +16,8 @@
 from anndata import AnnData
 from dask.array import Array as DaskArray
 from dask.dataframe import DataFrame as DaskDataFrame
-from datatree import DataTree
 from geopandas import GeoDataFrame
-from xarray import DataArray
+from xarray import DataArray, DataTree

 from spatialdata._core.spatialdata import SpatialData
 from spatialdata._utils import get_pyramid_levels
@@ -53,7 +52,7 @@ def _get_transformations_from_ngff_dict(
     list_of_ngff_transformations = [NgffBaseTransformation.from_dict(d) for d in list_of_encoded_ngff_transformations]
     list_of_transformations = [BaseTransformation.from_ngff(t) for t in list_of_ngff_transformations]
     transformations = {}
-    for ngff_t, t in zip(list_of_ngff_transformations, list_of_transformations):
+    for ngff_t, t in zip(list_of_ngff_transformations, list_of_transformations, strict=True):
         assert ngff_t.output_coordinate_system is not None
         transformations[ngff_t.output_coordinate_system.name] = t
     return transformations
@@ -213,7 +212,7 @@ def get_dask_backing_files(element: SpatialData | SpatialElement | AnnData) -> l
 def _(element: SpatialData) -> list[str]:
     files: set[str] = set()
     for e in element._gen_spatial_element_values():
-        if isinstance(e, (DataArray, DataTree, DaskDataFrame)):
+        if isinstance(e, DataArray | DataTree | DaskDataFrame):
             files = files.union(get_dask_backing_files(e))
     return list(files)
diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py
index 8569c456..b992c287 100644
--- a/src/spatialdata/_io/format.py
+++ b/src/spatialdata/_io/format.py
@@ -57,7 +57,7 @@ def generate_coordinate_transformations(self, shapes: list[tuple[Any]]) -> None
         # calculate minimal 'scale' transform based on pyramid dims
         for shape in shapes:
             assert len(shape) == len(data_shape)
-            scale = [full / level for full, level in zip(data_shape, shape)]
+            scale = [full / level for full, level in zip(data_shape, shape, strict=True)]
             from spatialdata.transformations.ngff.ngff_transformations import NgffScale

             coordinate_transformations.append([NgffScale(scale=scale).to_dict()])
@@ -98,7 +98,7 @@ def validate_coordinate_transformations(
             json1 = [json.dumps(p.to_dict()) for p in parsed]
             import numpy as np

-            assert np.all([j0 == j1 for j0, j1 in zip(json0, json1)])
+            assert np.all([j0 == j1 for j0, j1 in zip(json0, json1, strict=True)])

 # eventually we are fully compliant with NGFF and we can drop SPATIALDATA_FORMAT_VERSION and simply rely on
 # "version"; still, until the coordinate transformations make it into NGFF, we need to have our extension
diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py
index 3a047734..d4113fb3 100644
--- a/src/spatialdata/_io/io_points.py
+++ b/src/spatialdata/_io/io_points.py
@@ -1,7 +1,6 @@
 import os
 from collections.abc import MutableMapping
 from pathlib import Path
-from typing import Union

 import zarr
 from dask.dataframe import DataFrame as DaskDataFrame  # type: ignore[attr-defined]
@@ -22,10 +21,10 @@


 def _read_points(
-    store: Union[str, Path, MutableMapping, zarr.Group],  # type: ignore[type-arg]
+    store: str | Path | MutableMapping | zarr.Group,  # type: ignore[type-arg]
 ) -> DaskDataFrame:
     """Read points from a zarr store."""
-    assert isinstance(store, (str, Path))
+    assert isinstance(store, str | Path)
     f = zarr.open(store, mode="r")

     version = _parse_version(f, expect_attrs_key=True)
diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py
index 3e3c8bc8..c7059bc8 100644
--- a/src/spatialdata/_io/io_raster.py
+++ b/src/spatialdata/_io/io_raster.py
@@ -1,10 +1,9 @@
 from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal

 import dask.array as da
 import numpy as np
 import zarr
-from datatree import DataTree
 from ome_zarr.format import Format
 from ome_zarr.io import ZarrLocation
 from ome_zarr.reader import Label, Multiscales, Node, Reader
@@ -14,7 +13,7 @@
 from ome_zarr.writer import write_labels as write_labels_ngff
 from ome_zarr.writer import write_multiscale as write_multiscale_ngff
 from ome_zarr.writer import write_multiscale_labels as write_multiscale_labels_ngff
-from xarray import DataArray
+from xarray import DataArray, Dataset, DataTree

 from spatialdata._io._utils import (
     _get_transformations_from_ngff_dict,
@@ -37,8 +36,8 @@
 )


-def _read_multiscale(store: Union[str, Path], raster_type: Literal["image", "labels"]) -> Union[DataArray, DataTree]:
-    assert isinstance(store, (str, Path))
+def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) -> DataArray | DataTree:
+    assert isinstance(store, str | Path)
     assert raster_type in ["image", "labels"]

     f = zarr.open(store, mode="r")
@@ -82,7 +81,7 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "lab
     # TODO: what to do with name? For now remove?
     # name = os.path.basename(node.metadata["name"])
     # if image, read channels metadata
-    channels: Optional[list[Any]] = None
+    channels: list[Any] | None = None
     if raster_type == "image":
         if legacy_channels_metadata is not None:
             channels = [d["label"] for d in legacy_channels_metadata["channels"]]
@@ -93,11 +92,15 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "lab
         multiscale_image = {}
         for i, d in enumerate(datasets):
             data = node.load(Multiscales).array(resolution=d, version=format.version)
-            multiscale_image[f"scale{i}"] = DataArray(
-                data,
-                name="image",
-                dims=axes,
-                coords={"c": channels} if channels is not None else {},
+            multiscale_image[f"scale{i}"] = Dataset(
+                {
+                    "image": DataArray(
+                        data,
+                        name="image",
+                        dims=axes,
+                        coords={"c": channels} if channels is not None else {},
+                    )
+                }
             )
         msi = DataTree.from_dict(multiscale_image)
         _set_transformations(msi, transformations)
@@ -115,13 +118,13 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "lab
 def _write_raster(
     raster_type: Literal["image", "labels"],
-    raster_data: Union[DataArray, DataTree],
+    raster_data: DataArray | DataTree,
     group: zarr.Group,
     name: str,
     format: Format = CurrentRasterFormat(),
-    storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
-    label_metadata: Optional[JSONDict] = None,
-    **metadata: Union[str, JSONDict, list[JSONDict]],
+    storage_options: JSONDict | list[JSONDict] | None = None,
+    label_metadata: JSONDict | None = None,
+    **metadata: str | JSONDict | list[JSONDict],
 ) -> None:
     assert raster_type in ["image", "labels"]
     # the argument "name" and "label_metadata" are only used for labels (to be precise, name is used in
@@ -229,12 +232,12 @@ def _get_group_for_writing_transformations() -> zarr.Group:


 def write_image(
-    image: Union[DataArray, DataTree],
+    image: DataArray | DataTree,
     group: zarr.Group,
     name: str,
     format: Format = CurrentRasterFormat(),
-    storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
-    **metadata: Union[str, JSONDict, list[JSONDict]],
+    storage_options: JSONDict | list[JSONDict] | None = None,
+    **metadata: str | JSONDict | list[JSONDict],
 ) -> None:
     _write_raster(
         raster_type="image",
@@ -248,12 +251,12 @@ def write_image(


 def write_labels(
-    labels: Union[DataArray, DataTree],
+    labels: DataArray | DataTree,
     group: zarr.Group,
     name: str,
     format: Format = CurrentRasterFormat(),
-    storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
-    label_metadata: Optional[JSONDict] = None,
+    storage_options: JSONDict | list[JSONDict] | None = None,
+    label_metadata: JSONDict | None = None,
     **metadata: JSONDict,
 ) -> None:
     _write_raster(
diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py
index 2ede8f45..c32ce1f3 100644
--- a/src/spatialdata/_io/io_shapes.py
+++ b/src/spatialdata/_io/io_shapes.py
@@ -1,6 +1,5 @@
 from collections.abc import MutableMapping
 from pathlib import Path
-from typing import Union

 import numpy as np
 import zarr
@@ -28,10 +27,10 @@


 def _read_shapes(
-    store: Union[str, Path, MutableMapping, zarr.Group],  # type: ignore[type-arg]
+    store: str | Path | MutableMapping | zarr.Group,  # type: ignore[type-arg]
 ) -> GeoDataFrame:
     """Read shapes from a zarr store."""
-    assert isinstance(store, (str, Path))
+    assert isinstance(store, str | Path)
     f = zarr.open(store, mode="r")

     version = _parse_version(f, expect_attrs_key=True)
     assert version is not None
diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py
index bb4d2540..2737d24b 100644
--- a/src/spatialdata/_io/io_zarr.py
+++ b/src/spatialdata/_io/io_zarr.py
@@ -2,7 +2,6 @@
 import os
 import warnings
 from pathlib import Path
-from typing import Optional, Union

 import zarr
 from anndata import AnnData
@@ -16,7 +15,7 @@
 from spatialdata._logging import logger


-def _open_zarr_store(store: Union[str, Path, zarr.Group]) -> tuple[zarr.Group, str]:
+def _open_zarr_store(store: str | Path | zarr.Group) -> tuple[zarr.Group, str]:
     """
     Open a zarr store (on-disk or remote) and return the zarr.Group object and the path to the store.

@@ -31,13 +30,13 @@ def _open_zarr_store(store: Union[str, Path, zarr.Group]) -> tuple[zarr.Group, s
     """
     f = store if isinstance(store, zarr.Group) else zarr.open(store, mode="r")
     # workaround: .zmetadata is being written as zmetadata (https://github.com/zarr-developers/zarr-python/issues/1121)
-    if isinstance(store, (str, Path)) and str(store).startswith("http") and len(f) == 0:
+    if isinstance(store, str | Path) and str(store).startswith("http") and len(f) == 0:
         f = zarr.open_consolidated(store, mode="r", metadata_key="zmetadata")
     f_store_path = f.store.store.path if isinstance(f.store, zarr.storage.ConsolidatedMetadataStore) else f.store.path
     return f, f_store_path


-def read_zarr(store: Union[str, Path, zarr.Group], selection: Optional[tuple[str]] = None) -> SpatialData:
+def read_zarr(store: str | Path | zarr.Group, selection: None | tuple[str] = None) -> SpatialData:
     """
     Read a SpatialData dataset from a zarr store (on-disk or remote).
diff --git a/src/spatialdata/_types.py b/src/spatialdata/_types.py
index 554c4e63..c0f22b32 100644
--- a/src/spatialdata/_types.py
+++ b/src/spatialdata/_types.py
@@ -1,10 +1,5 @@
-from __future__ import annotations
-
-from typing import Union
-
 import numpy as np
-from datatree import DataTree
-from xarray import DataArray
+from xarray import DataArray, DataTree

 __all__ = ["ArrayLike", "ColorLike", "DTypeLike", "Raster_T"]

@@ -14,7 +9,7 @@
     ArrayLike = NDArray[np.float64]
 except (ImportError, TypeError):
     ArrayLike = np.ndarray  # type: ignore[misc]
-    DTypeLike = np.dtype  # type: ignore[misc]
+    DTypeLike = np.dtype  # type: ignore[misc, assignment]

-Raster_T = Union[DataArray, DataTree]
-ColorLike = Union[tuple[float, ...], str]
+Raster_T = DataArray | DataTree
+ColorLike = tuple[float, ...] | str
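Note: with `from __future__ import annotations` removed from _types.py, the module-level aliases are evaluated eagerly at import time. That is fine on Python >= 3.10, where `|` on classes builds a types.UnionType that also works with isinstance(). A sketch:

    # Sketch: module-level unions are evaluated eagerly without the future
    # import; on 3.10+ they become types.UnionType and work with isinstance().
    import numpy as np
    from xarray import DataArray, DataTree

    Raster_T = DataArray | DataTree

    image = DataArray(np.zeros((1, 2, 2)), dims=("c", "y", "x"))
    print(isinstance(image, Raster_T))  # True
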
diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py
index 0e45f380..4d78b7d6 100644
--- a/src/spatialdata/_utils.py
+++ b/src/spatialdata/_utils.py
@@ -3,16 +3,15 @@
 import functools
 import re
 import warnings
-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from itertools import islice
-from typing import Any, Callable, TypeVar, Union
+from typing import Any, TypeVar

 import numpy as np
 import pandas as pd
 from anndata import AnnData
 from dask import array as da
-from datatree import DataTree
-from xarray import DataArray
+from xarray import DataArray, Dataset, DataTree

 from spatialdata._types import ArrayLike
 from spatialdata.transformations import (
@@ -23,7 +22,7 @@
 )

 # I was using "from numbers import Number" but this led to mypy errors, so I switched to the following:
-Number = Union[int, float]
+Number = int | float
 RT = TypeVar("RT")

@@ -136,7 +135,7 @@ def _compute_paddings(data: DataArray, axis: str) -> tuple[int, int]:
             assert len(v.values()) == 1
             xdata = v.values().__iter__().__next__()
             if 0 not in xdata.shape:
-                d[k] = xdata
+                d[k] = Dataset({"image": xdata})
         unpadded = DataTree.from_dict(d)
     else:
         raise TypeError(f"Unsupported type: {type(raster)}")
diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py
index c65bc4f5..0b3e7e18 100644
--- a/src/spatialdata/dataloader/datasets.py
+++ b/src/spatialdata/dataloader/datasets.py
@@ -1,22 +1,21 @@
 from __future__ import annotations

 import warnings
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 from functools import partial
 from itertools import chain
 from types import MappingProxyType
-from typing import Any, Callable
+from typing import Any

 import anndata as ad
 import numpy as np
 import pandas as pd
 from anndata import AnnData
-from datatree import DataTree
 from geopandas import GeoDataFrame
 from pandas import CategoricalDtype
 from scipy.sparse import issparse
 from torch.utils.data import Dataset
-from xarray import DataArray
+from xarray import DataArray, DataTree

 from spatialdata._core.centroids import get_centroids
 from spatialdata._core.operations.transform import transform
diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py
index ded4cc51..05b1b334 100644
--- a/src/spatialdata/datasets.py
+++ b/src/spatialdata/datasets.py
@@ -10,13 +10,12 @@
 import scipy
 from anndata import AnnData
 from dask.dataframe import DataFrame as DaskDataFrame
-from datatree import DataTree
 from geopandas import GeoDataFrame
 from numpy.random import default_rng
 from shapely.affinity import translate
 from shapely.geometry import MultiPolygon, Point, Polygon
 from skimage.segmentation import slic
-from xarray import DataArray
+from xarray import DataArray, DataTree

 from spatialdata._core.operations.aggregate import aggregate
 from spatialdata._core.query.relational_query import get_element_instances
@@ -183,7 +182,7 @@ def _image_blobs(
         masks = []
         for i in range(n_channels):
             mask = self._generate_blobs(length=length, seed=i)
-            mask = (mask - mask.min()) / mask.ptp()  # type: ignore[attr-defined]
+            mask = (mask - mask.min()) / np.ptp(mask)  # type: ignore[attr-defined]
             masks.append(mask)

         x = np.stack(masks, axis=0)
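Note: the `ndarray.ptp()` method was removed in NumPy 2.0; the `np.ptp()` function is the portable spelling, hence the change in `_image_blobs`. A sketch of the same normalization:

    # Sketch: np.ptp (peak-to-peak, i.e. max - min) replaces the removed
    # ndarray.ptp() method when rescaling each blob mask to [0, 1].
    import numpy as np

    mask = np.array([2.0, 5.0, 11.0])
    normalized = (mask - mask.min()) / np.ptp(mask)  # np.ptp(mask) == 9.0
    print(normalized)  # [0.         0.33333333 1.        ]
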
diff --git a/src/spatialdata/models/__init__.py b/src/spatialdata/models/__init__.py
index 0c2bacd8..250410be 100644
--- a/src/spatialdata/models/__init__.py
+++ b/src/spatialdata/models/__init__.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 from spatialdata.models._utils import (
     C,
     SpatialElement,
@@ -52,4 +50,5 @@
     "get_table_keys",
     "get_channels",
     "force_2d",
+    "RasterSchema",
 ]
diff --git a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py
index 71cf8c9c..f5df85e3 100644
--- a/src/spatialdata/models/_utils.py
+++ b/src/spatialdata/models/_utils.py
@@ -1,8 +1,6 @@
-from __future__ import annotations
-
 import warnings
 from functools import singledispatch
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any, TypeAlias

 import dask.dataframe as dd
 import geopandas
@@ -10,15 +8,14 @@
 import pandas as pd
 from anndata import AnnData
 from dask.dataframe import DataFrame as DaskDataFrame
-from datatree import DataTree
 from geopandas import GeoDataFrame
 from shapely.geometry import MultiPolygon, Point, Polygon
-from xarray import DataArray
+from xarray import DataArray, DataTree

 from spatialdata._logging import logger
 from spatialdata.transformations.transformations import BaseTransformation

-SpatialElement = Union[DataArray, DataTree, GeoDataFrame, DaskDataFrame]
+SpatialElement: TypeAlias = DataArray | DataTree | GeoDataFrame | DaskDataFrame
 TRANSFORM_KEY = "transform"
 DEFAULT_COORDINATE_SYSTEM = "global"
 ValidAxis_t = str
@@ -48,9 +45,9 @@ def has_type_spatial_element(e: Any) -> bool:
     Returns
     -------
     Whether the object is a SpatialElement
-    (i.e in Union[DataArray, DataTree, GeoDataFrame, DaskDataFrame])
+    (i.e in DataArray | DataTree | GeoDataFrame | DaskDataFrame)
     """
-    return isinstance(e, (DataArray, DataTree, GeoDataFrame, DaskDataFrame))
+    return isinstance(e, DataArray | DataTree | GeoDataFrame | DaskDataFrame)


 # added this code as part of a refactoring to catch errors earlier
@@ -341,7 +338,7 @@ def force_2d(gdf: GeoDataFrame) -> None:
     gdf.geometry = new_shapes


-def get_raster_model_from_data_dims(dims: tuple[str, ...]) -> type[RasterSchema]:
+def get_raster_model_from_data_dims(dims: tuple[str, ...]) -> type["RasterSchema"]:
     """
     Get the raster model from the dimensions of the data.
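Note: models/_utils.py quotes the return annotation as `type["RasterSchema"]` because, without postponed annotation evaluation, a name defined later (or only under TYPE_CHECKING) must be written as a string. A sketch with hypothetical names:

    # Sketch (hypothetical names): with eager annotation evaluation, a name
    # defined later in the module must be quoted where it is annotated.
    def make_default_model() -> "RasterModel":
        return RasterModel()

    class RasterModel:
        dims = ("c", "y", "x")

    print(type(make_default_model()).__name__)  # RasterModel
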
diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py
index 72d92780..9b417e43 100644
--- a/src/spatialdata/models/models.py
+++ b/src/spatialdata/models/models.py
@@ -1,12 +1,10 @@
 """Models and schema for SpatialData."""

-from __future__ import annotations
-
 import warnings
 from collections.abc import Mapping, Sequence
 from functools import singledispatchmethod
 from pathlib import Path
-from typing import Any, Literal, Union
+from typing import Any, Literal, TypeAlias

 import dask.dataframe as dd
 import numpy as np
@@ -15,7 +13,6 @@
 from dask.array import Array as DaskArray
 from dask.array.core import from_array
 from dask.dataframe import DataFrame as DaskDataFrame
-from datatree import DataTree
 from geopandas import GeoDataFrame, GeoSeries
 from multiscale_spatial_image import to_multiscale
 from multiscale_spatial_image.to_multiscale.to_multiscale import Methods
@@ -25,7 +22,7 @@
 from shapely.geometry.collection import GeometryCollection
 from shapely.io import from_geojson, from_ragged_array
 from spatial_image import to_spatial_image
-from xarray import DataArray
+from xarray import DataArray, DataTree
 from xarray_schema.components import (
     ArrayTypeSchema,
     AttrSchema,
@@ -53,13 +50,8 @@
 from spatialdata.transformations.transformations import BaseTransformation, Identity

 # Types
-Chunks_t = Union[
-    int,
-    tuple[int, ...],
-    tuple[tuple[int, ...], ...],
-    Mapping[Any, Union[None, int, tuple[int, ...]]],
-]
-ScaleFactors_t = Sequence[Union[dict[str, int], int]]
+Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]]
+ScaleFactors_t = Sequence[dict[str, int] | int]

 Transform_s = AttrSchema(BaseTransformation, None)
 ATTRS_KEY = "spatialdata_attrs"
@@ -157,7 +149,7 @@ def parse(
         if "name" in kwargs:
             raise ValueError("The `name` argument is not (yet) supported for raster data.")
         # if dims is specified inside the data, get the value of dims from the data
-        if isinstance(data, (DataArray)):
+        if isinstance(data, DataArray):
             if not isinstance(data.data, DaskArray):  # numpy -> dask
                 data.data = from_array(data.data)
             if dims is not None:
@@ -173,7 +165,7 @@ def parse(
                 raise ValueError(f"Wrong `dims`: {dims}. Expected {cls.dims.dims}.")
             _reindex = lambda d: d
         # if there are no dims in the data, use the model's dims or provided dims
-        elif isinstance(data, (np.ndarray, DaskArray)):
+        elif isinstance(data, np.ndarray | DaskArray):
             if not isinstance(data, DaskArray):  # numpy -> dask
                 data = from_array(data)
             if dims is None:
@@ -247,7 +239,7 @@ def _(self, data: DataArray) -> None:

     @validate.register(DataTree)
     def _(self, data: DataTree) -> None:
-        for j, k in zip(data.keys(), [f"scale{i}" for i in np.arange(len(data.keys()))]):
+        for j, k in zip(data.keys(), [f"scale{i}" for i in np.arange(len(data.keys()))], strict=True):
             if j != k:
                 raise ValueError(f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`.")
         name = {list(data[i].data_vars.keys())[0] for i in data}
@@ -367,7 +359,7 @@ def validate(cls, data: GeoDataFrame) -> None:
         if len(data[cls.GEOMETRY_KEY]) == 0:
             raise ValueError(f"Column `{cls.GEOMETRY_KEY}` is empty." + SUGGESTION)
         geom_ = data[cls.GEOMETRY_KEY].values[0]
-        if not isinstance(geom_, (Polygon, MultiPolygon, Point)):
+        if not isinstance(geom_, Polygon | MultiPolygon | Point):
             raise ValueError(
                 f"Column `{cls.GEOMETRY_KEY}` can only contain `Point`, `Polygon` or `MultiPolygon` shapes,"
                 f"but it contains {type(geom_)}." + SUGGESTION
             )
@@ -1035,15 +1027,15 @@ def parse(
         return convert_region_column_to_categorical(adata)


-Schema_t = Union[
-    type[Image2DModel],
-    type[Image3DModel],
-    type[Labels2DModel],
-    type[Labels3DModel],
-    type[PointsModel],
-    type[ShapesModel],
-    type[TableModel],
-]
+Schema_t: TypeAlias = (
+    type[Image2DModel]
+    | type[Image3DModel]
+    | type[Labels2DModel]
+    | type[Labels3DModel]
+    | type[PointsModel]
+    | type[ShapesModel]
+    | type[TableModel]
+)


 def get_model(
@@ -1069,7 +1061,7 @@ def _validate_and_return(
         schema().validate(e)
         return schema

-    if isinstance(e, (DataArray, DataTree)):
+    if isinstance(e, DataArray | DataTree):
         axes = get_axes_names(e)
         if "c" in axes:
             if "z" in axes:
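Note: models.py now marks its aliases with `typing.TypeAlias`, which keeps type checkers from reading a multi-line `|` union as an ordinary assignment once the `Union[...]` spelling is gone. A simplified sketch of the same pattern:

    # Sketch (simplified alias): TypeAlias marks the assignment as a type
    # alias for checkers now that Union[...] has been rewritten with `|`.
    from collections.abc import Mapping
    from typing import Any, TypeAlias

    Chunks_t: TypeAlias = int | tuple[int, ...] | Mapping[Any, None | int | tuple[int, ...]]

    print(Chunks_t)
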
diff --git a/src/spatialdata/testing.py b/src/spatialdata/testing.py
index ef9fcd2c..253f6e50 100644
--- a/src/spatialdata/testing.py
+++ b/src/spatialdata/testing.py
@@ -4,12 +4,10 @@
 from anndata.tests.helpers import assert_equal as assert_anndata_equal
 from dask.dataframe import DataFrame as DaskDataFrame
 from dask.dataframe.tests.test_dataframe import assert_eq as assert_dask_dataframe_equal
-from datatree import DataTree
-from datatree.testing import assert_equal as assert_datatree_equal
 from geopandas import GeoDataFrame
 from geopandas.testing import assert_geodataframe_equal
-from xarray import DataArray
-from xarray.testing import assert_equal as assert_xarray_equal
+from xarray import DataArray, DataTree
+from xarray.testing import assert_equal

 from spatialdata import SpatialData
 from spatialdata._core._elements import Elements
@@ -115,10 +113,8 @@ def assert_elements_are_identical(
     # compare the elements
     if isinstance(element0, AnnData):
         assert_anndata_equal(element0, element1)
-    elif isinstance(element0, DataArray):
-        assert_xarray_equal(element0, element1)
-    elif isinstance(element0, DataTree):
-        assert_datatree_equal(element0, element1)
+    elif isinstance(element0, DataArray | DataTree):
+        assert_equal(element0, element1)
     elif isinstance(element0, GeoDataFrame):
         assert_geodataframe_equal(element0, element1, check_less_precise=True)
     else:
diff --git a/src/spatialdata/transformations/_utils.py b/src/spatialdata/transformations/_utils.py
index 98645e96..3e203924 100644
--- a/src/spatialdata/transformations/_utils.py
+++ b/src/spatialdata/transformations/_utils.py
@@ -5,9 +5,8 @@

 import numpy as np
 from dask.dataframe import DataFrame as DaskDataFrame
-from datatree import DataTree
 from geopandas import GeoDataFrame
-from xarray import DataArray
+from xarray import DataArray, Dataset, DataTree

 from spatialdata._logging import logger
 from spatialdata._types import ArrayLike
@@ -223,8 +222,10 @@ def _(data: DataTree) -> DataTree:
             offset = max_dim / n / 2
             coords = np.linspace(0, max_dim, n + 1)[:-1] + offset
             new_coords[ax] = coords
-        out[name] = dt[img_name].assign_coords(new_coords)
-    datatree = DataTree.from_dict(d=out)
+
+        # Xarray now only accepts Dataset as dictionary values for DataTree.from_dict.
+        out[name] = Dataset({img_name: dt[img_name].assign_coords(new_coords)})
+    datatree = DataTree.from_dict(out)
     # this is to trigger the validation of the dims
     _ = get_axes_names(datatree)
     return datatree
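Note: the comment added in transformations/_utils.py states the constraint behind many hunks in this diff: newer xarray accepts only Dataset values in DataTree.from_dict, so every pyramid level is wrapped as `Dataset({"image": ...})`. A sketch, also using the unified `xarray.testing.assert_equal` that testing.py switched to:

    # Sketch: each scale is wrapped as a Dataset before DataTree.from_dict;
    # xarray.testing.assert_equal now compares DataTree objects directly.
    import numpy as np
    from xarray import DataArray, Dataset, DataTree
    from xarray.testing import assert_equal

    scale = Dataset({"image": DataArray(np.ones((2, 2)), dims=("y", "x"))})
    tree_a = DataTree.from_dict({"scale0": scale})
    tree_b = DataTree.from_dict({"scale0": scale})
    assert_equal(tree_a, tree_b)  # raises AssertionError on mismatch
    print("trees match")
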
diff --git a/src/spatialdata/transformations/ngff/ngff_coordinate_system.py b/src/spatialdata/transformations/ngff/ngff_coordinate_system.py
index 0cd7d35c..7f5a825f 100644
--- a/src/spatialdata/transformations/ngff/ngff_coordinate_system.py
+++ b/src/spatialdata/transformations/ngff/ngff_coordinate_system.py
@@ -27,9 +27,9 @@
 class NgffAxis:
     name: str
     type: str
-    unit: Optional[str]
+    unit: str | None

-    def __init__(self, name: str, type: str, unit: Optional[str] = None):
+    def __init__(self, name: str, type: str, unit: str | None = None):
         self.name = name
         self.type = type
         self.unit = unit
@@ -138,7 +138,7 @@ def equal_up_to_the_name(self, other: NgffCoordinateSystem) -> bool:
         """Checks if two coordinate systems are the same based on the axes (ignoring the coordinate systems names)."""
         return self._axes == other._axes

-    def subset(self, axes_names: list[str], new_name: Optional[str] = None) -> NgffCoordinateSystem:
+    def subset(self, axes_names: list[str], new_name: str | None = None) -> NgffCoordinateSystem:
         """
         Return a new coordinate system subsetting the axes.

@@ -193,7 +193,7 @@ def get_axis(self, name: str) -> NgffAxis:

     @staticmethod
     def merge(
-        coord_sys1: NgffCoordinateSystem, coord_sys2: NgffCoordinateSystem, new_name: Optional[str] = None
+        coord_sys1: NgffCoordinateSystem, coord_sys2: NgffCoordinateSystem, new_name: str | None = None
     ) -> NgffCoordinateSystem:
         """
         Merge two coordinate systems

@@ -256,7 +256,7 @@ def _get_spatial_axes(
     return [axis.name for axis in coordinate_system._axes if axis.type == "space"]


-def _make_cs(ndim: Literal[2, 3], name: Optional[str] = None, unit: Optional[str] = None) -> NgffCoordinateSystem:
+def _make_cs(ndim: Literal[2, 3], name: str | None = None, unit: str | None = None) -> NgffCoordinateSystem:
     """helper function to make a yx or zyx coordinate system"""
     if ndim == 2:
         axes = [
@@ -278,7 +278,7 @@ def _make_cs(ndim: Literal[2, 3], name: Optional[str] = None, unit: Optional[str
     return NgffCoordinateSystem(name=name, axes=axes)


-def yx_cs(name: Optional[str] = None, unit: Optional[str] = None) -> NgffCoordinateSystem:
+def yx_cs(name: str | None = None, unit: str | None = None) -> NgffCoordinateSystem:
     """
     Helper function to create a 2D yx coordinate system.

@@ -296,7 +296,7 @@ def yx_cs(name: Optional[str] = None, unit: Optional[str] = None) -> NgffCoordin
     return _make_cs(name=name, ndim=2, unit=unit)


-def zyx_cs(name: Optional[str] = None, unit: Optional[str] = None) -> NgffCoordinateSystem:
+def zyx_cs(name: str | None = None, unit: str | None = None) -> NgffCoordinateSystem:
     """
     Helper function to create a 3D zyx coordinate system.
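Note: the `Optional[str]` to `str | None` rewrites here are purely cosmetic; a PEP 604 union with None is equivalent and drops the typing import. A sketch with a hypothetical simplified helper:

    # Sketch (hypothetical simplified helper): `str | None` is equivalent
    # to Optional[str] and needs no typing import.
    def make_cs(ndim: int, name: str | None = None, unit: str | None = None) -> str:
        return f"{name or 'cs'}:{ndim}d:{unit or 'unitless'}"

    print(make_cs(2))  # cs:2d:unitless
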
diff --git a/src/spatialdata/transformations/ngff/ngff_transformations.py b/src/spatialdata/transformations/ngff/ngff_transformations.py
index 123a926b..fc8e0634 100644
--- a/src/spatialdata/transformations/ngff/ngff_transformations.py
+++ b/src/spatialdata/transformations/ngff/ngff_transformations.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 import math
 from abc import ABC, abstractmethod
 from numbers import Number
@@ -32,7 +30,6 @@
 # http://api.csswg.org/bikeshed/?url=https://raw.githubusercontent.com/bogovicj/ngff/coord-transforms/latest/index.bs
 # Transformation_t = Dict[str, Union[str, List[int], List[str], List[Dict[str, Any]]]]
 Transformation_t = dict[str, Any]
-NGFF_TRANSFORMATIONS: dict[str, type[NgffBaseTransformation]] = {}


 class NgffBaseTransformation(ABC):
@@ -75,11 +72,11 @@ def __repr__(self) -> str:

     @classmethod
     @abstractmethod
-    def _from_dict(cls, d: Transformation_t) -> NgffBaseTransformation:
+    def _from_dict(cls, d: Transformation_t) -> "NgffBaseTransformation":
         pass

     @classmethod
-    def from_dict(cls, d: Transformation_t) -> NgffBaseTransformation:
+    def from_dict(cls, d: Transformation_t) -> "NgffBaseTransformation":
         """
         Initialize a transformation from the Python dict of its json representation.

@@ -138,7 +135,7 @@ def _update_dict_with_input_output_cs(self, d: Transformation_t) -> None:
             d["output"] = d["output"].to_dict()

     @abstractmethod
-    def inverse(self) -> NgffBaseTransformation:
+    def inverse(self) -> "NgffBaseTransformation":
         """Return the inverse of the transformation."""

     @abstractmethod
@@ -160,7 +157,7 @@ def transform_points(self, points: ArrayLike) -> ArrayLike:
         """

     @abstractmethod
-    def to_affine(self) -> NgffAffine:
+    def to_affine(self) -> "NgffAffine":
         """Convert the transformation to an affine transformation, whenever the conversion can be made."""

     def _validate_transform_points_shapes(self, input_size: int, points_shape: tuple[int, ...]) -> None:
@@ -181,7 +178,7 @@ def _validate_transform_points_shapes(self, input_size: int, points_shape: tuple
             )

     # order of the composition: self is applied first, then the transformation passed as argument
-    def compose_with(self, transformation: NgffBaseTransformation) -> NgffBaseTransformation:
+    def compose_with(self, transformation: "NgffBaseTransformation") -> "NgffBaseTransformation":
         """
         Compose the transfomation object with another transformation

@@ -227,6 +224,9 @@ def _parse_list_into_array(array: Union[list[Number], list[list[Number]], ArrayL
         return array


+NGFF_TRANSFORMATIONS: dict[str, type[NgffBaseTransformation]] = {}
+
+
 # A note on affine transformations and their matrix representation.
 # Some transformations can be interpreted as (n-dimensional) affine transformations; explicitly these transformations
 # are:
@@ -277,6 +277,118 @@ def _parse_list_into_array(array: Union[list[Number], list[list[Number]], ArrayL
 # space invariant (i.e. it does not map finite points to the line at the infinity).
 # For a primer you can look here: https://en.wikipedia.org/wiki/Affine_space#Relation_to_projective_spaces
 # For more information please consult a linear algebra textbook.
+
+
+class NgffAffine(NgffBaseTransformation):
+    """The Affine transformation from the NGFF specification."""
+
+    def __init__(
+        self,
+        affine: Union[ArrayLike, list[list[Number]]],
+        input_coordinate_system: Optional[NgffCoordinateSystem] = None,
+        output_coordinate_system: Optional[NgffCoordinateSystem] = None,
+    ) -> None:
+        """
+        Init the NgffAffine object.
+
+        Parameters
+        ----------
+        affine
+            A list of lists of numbers or a matrix specifying the affine transformation.
+        input_coordinate_system
+            Input coordinate system of the transformation.
+        output_coordinate_system
+            Output coordinate system of the transformation.
+        """
+        super().__init__(input_coordinate_system, output_coordinate_system)
+        self.affine = self._parse_list_into_array(affine)
+
+    @classmethod
+    def _from_dict(cls, d: Transformation_t) -> Self:  # type: ignore[valid-type]
+        assert isinstance(d["affine"], list)
+        last_row = [[0.0] * (len(d["affine"][0]) - 1) + [1.0]]
+        return cls(d["affine"] + last_row)
+
+    def to_dict(self) -> Transformation_t:
+        d = {
+            "type": "affine",
+            "affine": self.affine[:-1, :].tolist(),
+        }
+        self._update_dict_with_input_output_cs(d)
+        return d
+
+    def _repr_transformation_description(self, indent: int = 0) -> str:
+        s = ""
+        for row in self.affine:
+            s += f"{self._indent(indent)}{row}\n"
+        s = s[:-1]
+        return s
+
+    def inverse(self) -> NgffBaseTransformation:
+        inv = np.linalg.inv(self.affine)
+        return NgffAffine(
+            inv,
+            input_coordinate_system=self.output_coordinate_system,
+            output_coordinate_system=self.input_coordinate_system,
+        )
+
+    def _get_and_validate_axes(self) -> tuple[tuple[str, ...], tuple[str, ...]]:
+        input_axes, output_axes = self._get_axes_from_coordinate_systems()
+        return input_axes, output_axes
+
+    def transform_points(self, points: ArrayLike) -> ArrayLike:
+        input_axes, output_axes = self._get_and_validate_axes()
+        self._validate_transform_points_shapes(len(input_axes), points.shape)
+        p = np.vstack([points.T, np.ones(points.shape[0])])
+        q = self.affine @ p
+        return q[: len(output_axes), :].T  # type: ignore[no-any-return]
+
+    def to_affine(self) -> "NgffAffine":
+        return NgffAffine(
+            self.affine,
+            input_coordinate_system=self.input_coordinate_system,
+            output_coordinate_system=self.output_coordinate_system,
+        )
+
+    @classmethod
+    def _affine_matrix_from_input_and_output_axes(
+        cls, input_axes: tuple[str, ...], output_axes: tuple[str, ...]
+    ) -> ArrayLike:
+        """
+        computes a permutation matrix to match the input and output axes.
+
+        Parameters
+        ----------
+        input_axes
+            the input axes.
+        output_axes
+            the output axes.
+ """ + from spatialdata.models import C, X, Y, Z + + assert all(ax in (X, Y, Z, C) for ax in input_axes) + assert all(ax in (X, Y, Z, C) for ax in output_axes) + m = np.zeros((len(output_axes) + 1, len(input_axes) + 1)) + for output_ax in output_axes: + for input_ax in input_axes: + if output_ax == input_ax: + m[output_axes.index(output_ax), input_axes.index(input_ax)] = 1 + m[-1, -1] = 1 + return m + + @classmethod + def from_input_output_coordinate_systems( + cls, + input_coordinate_system: NgffCoordinateSystem, + output_coordinate_system: NgffCoordinateSystem, + ) -> "NgffAffine": + input_axes = input_coordinate_system.axes_names + output_axes = output_coordinate_system.axes_names + m = cls._affine_matrix_from_input_and_output_axes(input_axes, output_axes) + return cls( + affine=m, input_coordinate_system=input_coordinate_system, output_coordinate_system=output_coordinate_system + ) + + class NgffIdentity(NgffBaseTransformation): """The Identity transformation from the NGFF specification.""" @@ -569,116 +681,6 @@ def to_affine(self) -> NgffAffine: ) -class NgffAffine(NgffBaseTransformation): - """The Affine transformation from the NGFF specification.""" - - def __init__( - self, - affine: Union[ArrayLike, list[list[Number]]], - input_coordinate_system: Optional[NgffCoordinateSystem] = None, - output_coordinate_system: Optional[NgffCoordinateSystem] = None, - ) -> None: - """ - Init the NgffAffine object. - Parameters - ---------- - affine - A list of lists of numbers or a matrix specifying the affine transformation. - input_coordinate_system - Input coordinate system of the transformation. - output_coordinate_system - Output coordinate system of the transformation. - """ - super().__init__(input_coordinate_system, output_coordinate_system) - self.affine = self._parse_list_into_array(affine) - - @classmethod - def _from_dict(cls, d: Transformation_t) -> Self: # type: ignore[valid-type] - assert isinstance(d["affine"], list) - last_row = [[0.0] * (len(d["affine"][0]) - 1) + [1.0]] - return cls(d["affine"] + last_row) - - def to_dict(self) -> Transformation_t: - d = { - "type": "affine", - "affine": self.affine[:-1, :].tolist(), - } - self._update_dict_with_input_output_cs(d) - return d - - def _repr_transformation_description(self, indent: int = 0) -> str: - s = "" - for row in self.affine: - s += f"{self._indent(indent)}{row}\n" - s = s[:-1] - return s - - def inverse(self) -> NgffBaseTransformation: - inv = np.linalg.inv(self.affine) - return NgffAffine( - inv, - input_coordinate_system=self.output_coordinate_system, - output_coordinate_system=self.input_coordinate_system, - ) - - def _get_and_validate_axes(self) -> tuple[tuple[str, ...], tuple[str, ...]]: - input_axes, output_axes = self._get_axes_from_coordinate_systems() - return input_axes, output_axes - - def transform_points(self, points: ArrayLike) -> ArrayLike: - input_axes, output_axes = self._get_and_validate_axes() - self._validate_transform_points_shapes(len(input_axes), points.shape) - p = np.vstack([points.T, np.ones(points.shape[0])]) - q = self.affine @ p - return q[: len(output_axes), :].T # type: ignore[no-any-return] - - def to_affine(self) -> NgffAffine: - return NgffAffine( - self.affine, - input_coordinate_system=self.input_coordinate_system, - output_coordinate_system=self.output_coordinate_system, - ) - - @classmethod - def _affine_matrix_from_input_and_output_axes( - cls, input_axes: tuple[str, ...], output_axes: tuple[str, ...] 
- ) -> ArrayLike: - """ - computes a permutation matrix to match the input and output axes. - - Parameters - ---------- - input_axes - the input axes. - output_axes - the output axes. - """ - from spatialdata.models import C, X, Y, Z - - assert all(ax in (X, Y, Z, C) for ax in input_axes) - assert all(ax in (X, Y, Z, C) for ax in output_axes) - m = np.zeros((len(output_axes) + 1, len(input_axes) + 1)) - for output_ax in output_axes: - for input_ax in input_axes: - if output_ax == input_ax: - m[output_axes.index(output_ax), input_axes.index(input_ax)] = 1 - m[-1, -1] = 1 - return m - - @classmethod - def from_input_output_coordinate_systems( - cls, - input_coordinate_system: NgffCoordinateSystem, - output_coordinate_system: NgffCoordinateSystem, - ) -> NgffAffine: - input_axes = input_coordinate_system.axes_names - output_axes = output_coordinate_system.axes_names - m = cls._affine_matrix_from_input_and_output_axes(input_axes, output_axes) - return cls( - affine=m, input_coordinate_system=input_coordinate_system, output_coordinate_system=output_coordinate_system - ) - - class NgffRotation(NgffBaseTransformation): """The Rotation transformation from the NGFF specification.""" @@ -843,7 +845,7 @@ def _inferring_cs_infer_output_coordinate_system( assert isinstance(t.input_coordinate_system, NgffCoordinateSystem) if isinstance(t, NgffAffine): return None - elif isinstance(t, (NgffTranslation, NgffScale, NgffRotation, NgffIdentity)): + elif isinstance(t, NgffTranslation | NgffScale | NgffRotation | NgffIdentity): return t.input_coordinate_system elif isinstance(t, NgffMapAxis): return None @@ -1130,7 +1132,7 @@ def transform_points(self, points: ArrayLike) -> ArrayLike: input_columns = [points[:, input_axes.index(ax)] for ax in t.input_coordinate_system.axes_names] input_columns_stacked: ArrayLike = np.stack(input_columns, axis=1) output_columns_t = t.transform_points(input_columns_stacked) - for ax, col in zip(t.output_coordinate_system.axes_names, output_columns_t.T): + for ax, col in zip(t.output_coordinate_system.axes_names, output_columns_t.T, strict=True): output_columns[ax] = col output: ArrayLike = np.stack([output_columns[ax] for ax in output_axes], axis=1) return output diff --git a/src/spatialdata/transformations/operations.py b/src/spatialdata/transformations/operations.py index 21115257..354eea34 100644 --- a/src/spatialdata/transformations/operations.py +++ b/src/spatialdata/transformations/operations.py @@ -21,7 +21,7 @@ def set_transformation( element: SpatialElement, transformation: Union[BaseTransformation, dict[str, BaseTransformation]], - to_coordinate_system: Optional[str] = None, + to_coordinate_system: str | None = None, set_all: bool = False, write_to_sdata: Optional[SpatialData] = None, ) -> None: @@ -89,7 +89,7 @@ def set_transformation( def get_transformation( - element: SpatialElement, to_coordinate_system: Optional[str] = None, get_all: bool = False + element: SpatialElement, to_coordinate_system: str | None = None, get_all: bool = False ) -> Union[BaseTransformation, dict[str, BaseTransformation]]: """ Get the transformation/s of an element. @@ -132,9 +132,9 @@ def get_transformation( def remove_transformation( element: SpatialElement, - to_coordinate_system: Optional[str] = None, + to_coordinate_system: str | None = None, remove_all: bool = False, - write_to_sdata: Optional[SpatialData] = None, + write_to_sdata: SpatialData | None = None, ) -> None: """ Remove a transformation/s from an element, in-memory or from disk. 
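# The hunks above and below apply the same mechanical typing migration: `Optional[X]`
# becomes the PEP 604 spelling `X | None`, which only parses at runtime on Python 3.10
# and newer. A minimal runnable sketch of the equivalence, not part of the patch; the
# helper names are illustrative only.
from typing import Optional, Union


def old_style(name: Optional[str] = None) -> Union[int, None]:
    # Pre-3.10 spelling: still valid, but redundant once Python 3.9 support is dropped.
    return None if name is None else len(name)


def new_style(name: str | None = None) -> int | None:
    # PEP 604 spelling used throughout the patch; because it evaluates at runtime,
    # `from __future__ import annotations` can be dropped where self-references are
    # quoted instead (e.g. "NgffBaseTransformation" above).
    return None if name is None else len(name)


# PEP 604 unions are runtime objects accepted by isinstance(), which is why the patch
# also rewrites `isinstance(x, (A, B))` as `isinstance(x, A | B)`.
assert isinstance(1.5, int | float)
assert old_style("abc") == new_style("abc") == 3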
@@ -443,7 +443,7 @@ def align_elements_using_landmarks( moving_element: SpatialElement, reference_coordinate_system: str = "global", moving_coordinate_system: str = "global", - new_coordinate_system: Optional[str] = None, + new_coordinate_system: str | None = None, write_to_sdata: Optional[SpatialData] = None, ) -> BaseTransformation: """ diff --git a/src/spatialdata/transformations/transformations.py b/src/spatialdata/transformations/transformations.py index 585a731a..cda49231 100644 --- a/src/spatialdata/transformations/transformations.py +++ b/src/spatialdata/transformations/transformations.py @@ -84,16 +84,16 @@ def to_ngff( self, input_axes: tuple[ValidAxis_t, ...], output_axes: tuple[ValidAxis_t, ...], - unit: Optional[str] = None, - output_coordinate_system_name: Optional[str] = None, + unit: str | None = None, + output_coordinate_system_name: str | None = None, ) -> NgffBaseTransformation: pass def _get_default_coordinate_system( self, axes: tuple[ValidAxis_t, ...], - unit: Optional[str] = None, - name: Optional[str] = None, + unit: str | None = None, + name: str | None = None, default_to_global: bool = False, ) -> NgffCoordinateSystem: from spatialdata.transformations.ngff._utils import ( @@ -223,8 +223,8 @@ def to_ngff( self, input_axes: tuple[ValidAxis_t, ...], output_axes: tuple[ValidAxis_t, ...], - unit: Optional[str] = None, - output_coordinate_system_name: Optional[str] = None, + unit: str | None = None, + output_coordinate_system_name: str | None = None, ) -> NgffBaseTransformation: input_cs = self._get_default_coordinate_system(axes=input_axes, unit=unit) output_cs = self._get_default_coordinate_system( @@ -317,8 +317,8 @@ def to_ngff( self, input_axes: tuple[ValidAxis_t, ...], output_axes: tuple[ValidAxis_t, ...], - unit: Optional[str] = None, - output_coordinate_system_name: Optional[str] = None, + unit: str | None = None, + output_coordinate_system_name: str | None = None, ) -> NgffBaseTransformation: input_cs = self._get_default_coordinate_system(axes=input_axes, unit=unit) output_cs = self._get_default_coordinate_system( @@ -399,8 +399,8 @@ def to_ngff( self, input_axes: tuple[ValidAxis_t, ...], output_axes: tuple[ValidAxis_t, ...], - unit: Optional[str] = None, - output_coordinate_system_name: Optional[str] = None, + unit: str | None = None, + output_coordinate_system_name: str | None = None, ) -> NgffBaseTransformation: input_cs = self._get_default_coordinate_system(axes=input_axes, unit=unit) output_cs = self._get_default_coordinate_system( @@ -485,8 +485,8 @@ def to_ngff( self, input_axes: tuple[ValidAxis_t, ...], output_axes: tuple[ValidAxis_t, ...], - unit: Optional[str] = None, - output_coordinate_system_name: Optional[str] = None, + unit: str | None = None, + output_coordinate_system_name: str | None = None, ) -> NgffBaseTransformation: input_cs = self._get_default_coordinate_system(axes=input_axes, unit=unit) output_cs = self._get_default_coordinate_system( @@ -593,8 +593,8 @@ def to_ngff( self, input_axes: tuple[ValidAxis_t, ...], output_axes: tuple[ValidAxis_t, ...], - unit: Optional[str] = None, - output_coordinate_system_name: Optional[str] = None, + unit: str | None = None, + output_coordinate_system_name: str | None = None, ) -> NgffBaseTransformation: new_matrix = self.to_affine_matrix(input_axes, output_axes) input_cs = self._get_default_coordinate_system(axes=input_axes, unit=unit) @@ -716,8 +716,8 @@ def to_ngff( self, input_axes: tuple[ValidAxis_t, ...], output_axes: tuple[ValidAxis_t, ...], - unit: Optional[str] = None, - 
output_coordinate_system_name: Optional[str] = None, + unit: str | None = None, + output_coordinate_system_name: str | None = None, ) -> NgffBaseTransformation: input_cs = self._get_default_coordinate_system(axes=input_axes, unit=unit) output_cs = self._get_default_coordinate_system( @@ -755,7 +755,7 @@ def __eq__(self, other: Any) -> bool: def _get_current_output_axes( transformation: BaseTransformation, input_axes: tuple[ValidAxis_t, ...] ) -> tuple[ValidAxis_t, ...]: - if isinstance(transformation, (Identity, Translation, Scale)): + if isinstance(transformation, Identity | Translation | Scale): return input_axes elif isinstance(transformation, MapAxis): map_axis_input_axes = set(transformation.map_axis.values()) diff --git a/tests/conftest.py b/tests/conftest.py index 2ce363bf..e4498c7e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,14 +14,13 @@ import pytest from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame -from datatree import DataTree from geopandas import GeoDataFrame from numpy.random import default_rng from scipy import ndimage as ndi from shapely import linearrings, polygons from shapely.geometry import MultiPolygon, Point, Polygon from skimage import data -from xarray import DataArray +from xarray import DataArray, DataTree from spatialdata._core._deepcopy import deepcopy from spatialdata._core.spatialdata import SpatialData @@ -324,7 +323,7 @@ def _make_points(coordinates: np.ndarray) -> DaskDataFrame: def _make_squares(centroid_coordinates: np.ndarray, half_widths: list[float]) -> polygons: linear_rings = [] - for centroid, half_width in zip(centroid_coordinates, half_widths): + for centroid, half_width in zip(centroid_coordinates, half_widths, strict=True): min_coords = centroid - half_width max_coords = centroid + half_width diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py index 4f10c9c4..540161c7 100644 --- a/tests/core/operations/test_aggregations.py +++ b/tests/core/operations/test_aggregations.py @@ -1,5 +1,3 @@ -from typing import Optional - import geopandas import numpy as np import pandas as pd @@ -19,7 +17,7 @@ def _parse_shapes( - sdata_query_aggregation: SpatialData, by_shapes: Optional[str] = None, values_shapes: Optional[str] = None + sdata_query_aggregation: SpatialData, by_shapes: str | None = None, values_shapes: str | None = None ) -> GeoDataFrame: # only one between by_shapes and values_shapes can be None assert by_shapes is None or values_shapes is None diff --git a/tests/core/operations/test_rasterize.py b/tests/core/operations/test_rasterize.py index 9607d337..9ce5618b 100644 --- a/tests/core/operations/test_rasterize.py +++ b/tests/core/operations/test_rasterize.py @@ -6,12 +6,11 @@ import pytest from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame -from datatree import DataTree from geopandas import GeoDataFrame from multiscale_spatial_image import MultiscaleSpatialImage from shapely import MultiPolygon, box from spatial_image import SpatialImage -from xarray import DataArray +from xarray import DataArray, DataTree from spatialdata import SpatialData, get_extent from spatialdata._core.operations.rasterize import rasterize diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 6ef5ee39..e8031820 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -4,9 +4,8 @@ import numpy as np import pytest -from datatree import DataTree from 
geopandas.testing import geom_almost_equals -from xarray import DataArray +from xarray import DataArray, DataTree from spatialdata import transform from spatialdata._core.data_extent import are_extents_equal, get_extent @@ -541,13 +540,13 @@ def test_transform_elements_and_entire_spatial_data_object(full_sdata: SpatialDa assert set(d.keys()) == {"global", "my_space"} a2 = d["global"].to_affine_matrix(input_axes=("x",), output_axes=("x",)) assert np.allclose(a, a2) - if isinstance(element, (DataArray, DataTree)): + if isinstance(element, DataArray | DataTree): assert np.allclose(a, np.array([[1 / k, 0], [0, 1]])) else: assert np.allclose(a, np.array([[1 / k, -k / k], [0, 1]])) else: assert set(d.keys()) == {"my_space"} - if isinstance(element, (DataArray, DataTree)): + if isinstance(element, DataArray | DataTree): assert np.allclose(a, np.array([[1, k], [0, 1]])) else: assert np.allclose(a, np.array([[1, 0], [0, 1]])) @@ -605,7 +604,7 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop( # I'd say that in the general case maybe they are not necessarily identical, but in this case they are assert np.allclose(affine, affine2) assert np.allclose(affine, np.array([[1, -k], [0, 1]])) - elif isinstance(element, (DataArray, DataTree)): + elif isinstance(element, DataArray | DataTree): assert set(d.keys()) == {"my_space"} assert np.allclose(affine, np.array([[1, k], [0, 1]])) else: @@ -616,7 +615,7 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop( if full_sdata.locate_element(element) == ["shapes/proxy_element"]: # non multi-hop case, since there is a direct transformation assert np.allclose(affine, np.array([[1, 0], [0, 1]])) - elif isinstance(element, (DataArray, DataTree)): + elif isinstance(element, DataArray | DataTree): assert np.allclose(affine, np.array([[1, k], [0, 1]])) else: assert np.allclose(affine, np.array([[1, 0], [0, 1]])) diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index bc50a5a4..176e6060 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -8,10 +8,9 @@ import xarray from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame -from datatree import DataTree from geopandas import GeoDataFrame from shapely import MultiPolygon, Point, Polygon -from xarray import DataArray +from xarray import DataArray, DataTree from spatialdata._core.data_extent import get_extent from spatialdata._core.query.spatial_query import ( @@ -312,7 +311,7 @@ def test_query_raster( slices["z"] = slice(2, 7) if return_request_only: - assert isinstance(image_result, (dict, list)) + assert isinstance(image_result, dict | list) if multiple_boxes: for i, result in enumerate(image_result): if not (is_bb_3d and is_3d) and ("z" in result): @@ -334,16 +333,16 @@ def test_query_raster( expected_image = ximage.sel(**slices) if isinstance(image, DataArray): - assert isinstance(image_result, (DataArray, list)) + assert isinstance(image_result, DataArray | list) if multiple_boxes: - for result, expected in zip(image_result, expected_images): + for result, expected in zip(image_result, expected_images, strict=True): np.testing.assert_allclose(result, expected) else: np.testing.assert_allclose(image_result, expected_image) elif isinstance(image, DataTree): - assert isinstance(image_result, (DataTree, list)) + assert isinstance(image_result, DataTree | list) if multiple_boxes: - for result, expected in zip(image_result, expected_images): + for result, expected in 
zip(image_result, expected_images, strict=True): v = result["scale0"].values() assert len(v) == 1 xdata = v.__iter__().__next__() @@ -795,7 +794,7 @@ def test_query_with_clipping(sdata_blobs): maxy = 210 x_coords = [minx, maxx, maxx, minx, minx] y_coords = [miny, miny, maxy, maxy, miny] - polygon = Polygon(zip(x_coords, y_coords)) + polygon = Polygon(zip(x_coords, y_coords, strict=True)) queried_circles = polygon_query(circles, polygon=polygon, target_coordinate_system="global", clip=True) queried_polygons = polygon_query(polygons, polygon=polygon, target_coordinate_system="global", clip=True) diff --git a/tests/dataloader/__init__.py b/tests/dataloader/__init__.py index 9d24d590..5f165b15 100644 --- a/tests/dataloader/__init__.py +++ b/tests/dataloader/__init__.py @@ -3,7 +3,7 @@ try: from spatialdata.dataloader.datasets import ImageTilesDataset except ImportError as e: - _error: Union[str, None] = str(e) + _error: str | None = str(e) else: _error = None diff --git a/tests/io/test_format.py b/tests/io/test_format.py index 9cc303d7..ed290f04 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import pytest from shapely import GeometryType @@ -18,9 +18,9 @@ class TestFormat: @pytest.mark.parametrize("instance_key", [None, PointsModel.INSTANCE_KEY]) def test_format_points( self, - attrs_key: Optional[str], - feature_key: Optional[str], - instance_key: Optional[str], + attrs_key: str | None, + feature_key: str | None, + instance_key: str | None, ) -> None: metadata: dict[str, Any] = {attrs_key: {"version": Points_f.spatialdata_format_version}} format_metadata: dict[str, Any] = {attrs_key: {}} diff --git a/tests/io/test_pyramids_performance.py b/tests/io/test_pyramids_performance.py index 0f6c21ec..bd9d1a8f 100644 --- a/tests/io/test_pyramids_performance.py +++ b/tests/io/test_pyramids_performance.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING import dask import dask.array @@ -7,7 +7,6 @@ import pytest import xarray as xr import zarr -from datatree import DataTree from spatialdata import SpatialData from spatialdata._io import write_image @@ -35,8 +34,8 @@ def sdata_with_image(request: "_pytest.fixtures.SubRequest", tmp_path: Path) -> return SpatialData(images={"image": image}) -def count_chunks(array: Union[xr.DataArray, xr.Dataset, DataTree]) -> int: - if isinstance(array, DataTree): +def count_chunks(array: xr.DataArray | xr.Dataset | xr.DataTree) -> int: + if isinstance(array, xr.DataTree): array = array.ds # From `chunksizes`, we get only the number of chunks per axis. # By multiplying them, we get the total number of chunks in 2D/3D. 
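# A small runnable illustration, not from the repository, of the chunk counting that
# `count_chunks` performs above, assuming a dask-backed DataArray. With xarray
# 2024.10.0 and newer, `DataTree` ships in xarray itself (hence the import changes in
# these hunks), and the same logic can be applied to the dataset of each scale level.
import math

import dask.array as da
import xarray as xr

image = xr.DataArray(da.zeros((3, 256, 256), chunks=(3, 64, 64)), dims=("c", "y", "x"))

# `chunksizes` maps each dimension to the chunk sizes along that axis; multiplying the
# per-axis chunk counts gives the total number of chunks.
num_chunks = math.prod(len(sizes) for sizes in image.chunksizes.values())
assert num_chunks == 1 * 4 * 4  # one chunk along c, four along each of y and x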
@@ -73,11 +72,11 @@ def test_write_image_multiscale_performance(sdata_with_image: SpatialData, tmp_p ) # The number of chunks of scale level 0 - num_chunks_scale0 = count_chunks(image.scale0 if isinstance(image, DataTree) else image) + num_chunks_scale0 = count_chunks(image.scale0 if isinstance(image, xr.DataTree) else image) # The total number of chunks of all scale levels num_chunks_all_scales = ( sum(count_chunks(pyramid) for pyramid in image.children.values()) - if isinstance(image, DataTree) + if isinstance(image, xr.DataTree) else count_chunks(image) ) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 261b0b89..297c23d8 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -1,7 +1,8 @@ import os import tempfile +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable +from typing import Any import dask.dataframe as dd import numpy as np diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 4b5b218c..173e38ce 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -3,10 +3,11 @@ import os import re import tempfile +from collections.abc import Callable from functools import partial from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Callable +from typing import Any import dask.array.core import dask.dataframe as dd @@ -16,13 +17,12 @@ from anndata import AnnData from dask.array.core import from_array from dask.dataframe import DataFrame as DaskDataFrame -from datatree import DataTree from geopandas import GeoDataFrame from numpy.random import default_rng from shapely.geometry import MultiPolygon, Point, Polygon from shapely.io import to_ragged_array from spatial_image import to_spatial_image -from xarray import DataArray +from xarray import DataArray, DataTree from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike diff --git a/tests/transformations/test_transformations.py b/tests/transformations/test_transformations.py index 94651f66..ece2a8a3 100644 --- a/tests/transformations/test_transformations.py +++ b/tests/transformations/test_transformations.py @@ -557,7 +557,7 @@ def test_transform_coordinates(): DataArray(manual0, coords={"points": range(2), "dim": ["x", "y", "z"]}), DataArray(manual1, coords={"points": range(2), "dim": ["x", "y", "z"]}), ] - for t, e in zip(transformaions, expected): + for t, e in zip(transformaions, expected, strict=True): transformed = t._transform_coordinates(coords) xarray.testing.assert_allclose(transformed, e) @@ -577,7 +577,7 @@ def _assert_sequence_transformations_equal_up_to_intermediate_coordinate_systems if outer_sequence: assert t0.input_coordinate_system.name == t1.input_coordinate_system.name assert t0.output_coordinate_system.name == t1.output_coordinate_system.name - for sub0, sub1 in zip(t0.transformations, t1.transformations): + for sub0, sub1 in zip(t0.transformations, t1.transformations, strict=True): if isinstance(sub0, NgffSequence): assert isinstance(sub1, NgffSequence) _assert_sequence_transformations_equal_up_to_intermediate_coordinate_systems_names_and_units( diff --git a/tests/utils/test_element_utils.py b/tests/utils/test_element_utils.py index 077ffaa8..86e75887 100644 --- a/tests/utils/test_element_utils.py +++ b/tests/utils/test_element_utils.py @@ -3,8 +3,7 @@ import dask_image.ndinterp import pytest import xarray -from datatree import DataTree -from xarray import DataArray +from xarray import DataArray, DataTree from 
spatialdata._utils import unpad_raster from spatialdata.models import get_model diff --git a/tests/utils/test_testing.py b/tests/utils/test_testing.py index 4b101969..a181c87f 100644 --- a/tests/utils/test_testing.py +++ b/tests/utils/test_testing.py @@ -2,8 +2,7 @@ import numpy as np import pytest -from datatree import DataTree -from xarray import DataArray +from xarray import DataArray, DataTree from spatialdata import SpatialData, deepcopy from spatialdata.models import (