Skip to content

Commit

Permalink
Merge branch 'main' into sdata_attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaMarconato committed Nov 13, 2024
2 parents 37ebf37 + 42f7b6a commit dc4f508
Show file tree
Hide file tree
Showing 54 changed files with 357 additions and 379 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ jobs:
strategy:
fail-fast: false
matrix:
python: ["3.9", "3.12"]
python: ["3.10", "3.12"]
os: [ubuntu-latest]
include:
- os: macos-latest
python: "3.9"
python: "3.10"
- os: macos-latest
python: "3.12"
pip-flags: "--pre"
Expand Down
2 changes: 1 addition & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[mypy]
python_version = 3.9
python_version = 3.10
plugins = numpy.typing.mypy_plugin

ignore_errors = False
Expand Down
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,32 @@ fail_fast: false
default_language_version:
python: python3
default_stages:
- commit
- push
- pre-commit
- pre-push
minimum_pre_commit_version: 2.16.0
ci:
skip: []
repos:
- repo: https://github.com/psf/black
rev: 24.8.0
rev: 24.10.0
hooks:
- id: black
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
hooks:
- id: prettier
- repo: https://github.com/asottile/blacken-docs
rev: 1.18.0
rev: 1.19.1
hooks:
- id: blacken-docs
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.13.0
hooks:
- id: mypy
additional_dependencies: [numpy, types-requests]
exclude: tests/|docs/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.8
rev: v0.7.3
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
22 changes: 19 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,28 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [0.2.4] - xxxx-xx-xx
## [0.2.6] - TBD

### Fixed

- Updated deprecated default stages of `pre-commit` #771

## [0.2.5] - 2024-06-11

### Fixed

- Incompatibility issues due to newest release of `multiscale-spatial-image` #760

## [0.2.4] - 2024-06-11

### Major

- Enable vectorization of `bounding_box_query` for all `SpatialData` elements. #699

### Minor

- Added `shortest_path` parameter to `get_transformation_between_coordinate_systems`
- Added `get_pyramid_levels()` utils API
- Added `shortest_path` parameter to `get_transformation_between_coordinate_systems` #714
- Added `get_pyramid_levels()` utils API #719
- Improved ergonomics of `concatenate()` when element names are non-unique #720
- Improved performance of writing images with multiscales #577

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
Submodule notebooks updated 176 files
13 changes: 7 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ maintainers = [
urls.Documentation = "https://spatialdata.scverse.org/en/latest"
urls.Source = "https://github.com/scverse/spatialdata.git"
urls.Home-page = "https://github.com/scverse/spatialdata.git"
requires-python = ">=3.9"
requires-python = ">=3.10, <3.13" # include 3.13 once multiscale-spatial-image conflicts are resolved
dynamic= [
"version" # allow version to be set by git tags
]
Expand All @@ -28,8 +28,9 @@ dependencies = [
"dask>=2024.4.1",
"fsspec<=2023.6",
"geopandas>=0.14",
"multiscale_spatial_image>=1.0.0",
"multiscale_spatial_image>=2.0.1",
"networkx",
"numba",
"numpy",
"ome_zarr>=0.8.4",
"pandas",
Expand All @@ -42,8 +43,7 @@ dependencies = [
"scikit-image",
"scipy",
"typing_extensions>=4.8.0",
"xarray",
"xarray-datatree",
"xarray>=2024.10.0",
"xarray-schema",
"xarray-spatial>=0.3.5",
"zarr",
Expand All @@ -69,6 +69,7 @@ test = [
"pytest",
"pytest-cov",
"pytest-mock",
"torch",
]
torch = [
"torch"
Expand Down Expand Up @@ -100,7 +101,7 @@ filterwarnings = [

[tool.black]
line-length = 120
target-version = ['py39']
target-version = ['py310']
include = '\.pyi?$'
exclude = '''
(
Expand Down Expand Up @@ -145,7 +146,7 @@ exclude = [
"setup.py",
]
line-length = 120
target-version = "py39"
target-version = "py310"

[tool.ruff.lint]
ignore = [
Expand Down
3 changes: 1 addition & 2 deletions src/spatialdata/_core/_deepcopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from dask.array.core import Array as DaskArray
from dask.array.core import from_array
from dask.dataframe import DataFrame as DaskDataFrame
from datatree import DataTree
from geopandas import GeoDataFrame
from xarray import DataArray
from xarray import DataArray, DataTree

from spatialdata._core.spatialdata import SpatialData
from spatialdata.models._utils import SpatialElement
Expand Down
7 changes: 3 additions & 4 deletions src/spatialdata/_core/centroids.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import pandas as pd
import xarray as xr
from dask.dataframe import DataFrame as DaskDataFrame
from datatree import DataTree
from geopandas import GeoDataFrame
from shapely import MultiPolygon, Point, Polygon
from xarray import DataArray
from xarray import DataArray, DataTree

from spatialdata._core.operations.transform import transform
from spatialdata.models import get_axes_names
Expand Down Expand Up @@ -86,7 +85,7 @@ def _get_centroids_for_axis(xdata: xr.DataArray, axis: str) -> pd.DataFrame:
centroids[label_value] += count * i.values.item()

all_labels_values, all_labels_counts = da.unique(xdata.data, return_counts=True)
all_labels = dict(zip(all_labels_values.compute(), all_labels_counts.compute()))
all_labels = dict(zip(all_labels_values.compute(), all_labels_counts.compute(), strict=True))
for label_value in centroids:
centroids[label_value] /= all_labels[label_value]
centroids = dict(sorted(centroids.items(), key=lambda x: x[0]))
Expand Down Expand Up @@ -132,7 +131,7 @@ def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame:
if isinstance(first_geometry, Point):
xy = e.geometry.get_coordinates().values
else:
assert isinstance(first_geometry, (Polygon, MultiPolygon)), (
assert isinstance(first_geometry, Polygon | MultiPolygon), (
f"Expected a GeoDataFrame either composed entirely of circles (Points with the `radius` column) or"
f" Polygons/MultiPolygons. Found {type(first_geometry)} instead."
)
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_core/concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _concatenate_tables(
raise ValueError("`instance_key` must be specified if tables have different instance keys")

tables_l = []
for table_region_key, table_instance_key, table in zip(region_keys, instance_keys, tables):
for table_region_key, table_instance_key, table in zip(region_keys, instance_keys, tables, strict=True):
rename_dict = {}
if table_region_key != region_key:
rename_dict[table_region_key] = region_key
Expand Down Expand Up @@ -247,7 +247,7 @@ def _fix_ensure_unique_element_names(
tables[new_name] = table
tables_by_sdata.append(tables)
sdatas_fixed = []
for elements, tables in zip(elements_by_sdata, tables_by_sdata):
for elements, tables in zip(elements_by_sdata, tables_by_sdata, strict=True):
sdata = SpatialData.init_from_elements(elements, tables=tables)
sdatas_fixed.append(sdata)
return sdatas_fixed
11 changes: 5 additions & 6 deletions src/spatialdata/_core/data_extent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
import numpy as np
import pandas as pd
from dask.dataframe import DataFrame as DaskDataFrame
from datatree import DataTree
from geopandas import GeoDataFrame
from shapely import MultiPolygon, Point, Polygon
from xarray import DataArray
from xarray import DataArray, DataTree

from spatialdata._core.operations.transform import transform
from spatialdata._core.spatialdata import SpatialData
Expand Down Expand Up @@ -71,7 +70,7 @@ def _get_extent_of_polygons_multipolygons(
-------
The bounding box description.
"""
assert isinstance(shapes.geometry.iloc[0], (Polygon, MultiPolygon))
assert isinstance(shapes.geometry.iloc[0], Polygon | MultiPolygon)
axes = get_axes_names(shapes)
bounds = shapes["geometry"].bounds
return {ax: (bounds[f"min{ax}"].min(), bounds[f"max{ax}"].max()) for ax in axes}
Expand Down Expand Up @@ -201,7 +200,7 @@ def _(
new_max_coordinates_dict: dict[str, list[float]] = defaultdict(list)
mask = [has_images, has_labels, has_points, has_shapes]
include_spatial_elements = ["images", "labels", "points", "shapes"]
include_spatial_elements = [i for (i, v) in zip(include_spatial_elements, mask) if v]
include_spatial_elements = [i for (i, v) in zip(include_spatial_elements, mask, strict=True) if v]

if elements is None: # to shut up ruff
elements = []
Expand All @@ -217,7 +216,7 @@ def _(
assert isinstance(transformations, dict)
coordinate_systems = list(transformations.keys())
if coordinate_system in coordinate_systems:
if isinstance(element_obj, (DaskDataFrame, GeoDataFrame)):
if isinstance(element_obj, DaskDataFrame | GeoDataFrame):
extent = get_extent(element_obj, coordinate_system=coordinate_system, exact=exact)
else:
extent = get_extent(element_obj, coordinate_system=coordinate_system)
Expand Down Expand Up @@ -254,7 +253,7 @@ def _get_extent_of_shapes(e: GeoDataFrame) -> BoundingBoxDescription:
first_geometry = e_temp["geometry"].iloc[0]
if isinstance(first_geometry, Point):
return _get_extent_of_circles(e)
assert isinstance(first_geometry, (Polygon, MultiPolygon))
assert isinstance(first_geometry, Polygon | MultiPolygon)
return _get_extent_of_polygons_multipolygons(e)


Expand Down
5 changes: 2 additions & 3 deletions src/spatialdata/_core/operations/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from typing import TYPE_CHECKING

from datatree import DataTree
from xarray import DataArray
from xarray import DataArray, DataTree

from spatialdata.models import SpatialElement

Expand Down Expand Up @@ -115,7 +114,7 @@ def transform_to_data_extent(
}

for _, element_name, element in sdata_raster.gen_spatial_elements():
if isinstance(element, (DataArray, DataTree)):
if isinstance(element, DataArray | DataTree):
rasterized = rasterize(
element,
axes=data_extent_axes,
Expand Down
5 changes: 2 additions & 3 deletions src/spatialdata/_core/operations/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import numpy as np
import pandas as pd
from dask.dataframe import DataFrame as DaskDataFrame
from datatree import DataTree
from geopandas import GeoDataFrame
from scipy import sparse
from shapely import Point
from xarray import DataArray
from xarray import DataArray, DataTree
from xrspatial import zonal_stats

from spatialdata._core.operations._utils import _parse_element
Expand Down Expand Up @@ -246,7 +245,7 @@ def _create_sdata_from_table_and_shapes(
table = TableModel.parse(table, region=shapes_name, region_key=region_key, instance_key=instance_key)

# labels case, needs conversion from str to int
if isinstance(shapes, (DataArray, DataTree)):
if isinstance(shapes, DataArray | DataTree):
table.obs[instance_key] = table.obs[instance_key].astype(int)

if deepcopy:
Expand Down
7 changes: 3 additions & 4 deletions src/spatialdata/_core/operations/map.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping
from collections.abc import Callable, Iterable, Mapping
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

import dask.array as da
from dask.array.overlap import coerce_depth
from datatree import DataTree
from xarray import DataArray
from xarray import DataArray, DataTree

from spatialdata.models._utils import get_axes_names, get_channels, get_raster_model_from_data_dims
from spatialdata.transformations import get_transformation
Expand Down
3 changes: 1 addition & 2 deletions src/spatialdata/_core/operations/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import numpy as np
from dask.array import Array as DaskArray
from dask.dataframe import DataFrame as DaskDataFrame
from datatree import DataTree
from geopandas import GeoDataFrame
from shapely import Point
from xarray import DataArray
from xarray import DataArray, DataTree

from spatialdata._core.operations._utils import _parse_element
from spatialdata._core.operations.transform import transform
Expand Down
6 changes: 3 additions & 3 deletions src/spatialdata/_core/operations/rasterize_bins.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def rasterize_bins(
"""
element = sdata[bins]
table = sdata.tables[table_name]
if not isinstance(element, (GeoDataFrame, DaskDataFrame)):
if not isinstance(element, GeoDataFrame | DaskDataFrame):
raise ValueError("The bins should be a GeoDataFrame or a DaskDataFrame.")

_, region_key, instance_key = get_table_keys(table)
Expand All @@ -94,7 +94,7 @@ def rasterize_bins(
keys = ([value_key] if isinstance(value_key, str) else value_key) if value_key is not None else table.var_names

if (value_key is None or any(key in table.var_names for key in keys)) and not isinstance(
table.X, (csc_matrix, np.ndarray)
table.X, csc_matrix | np.ndarray
):
raise ValueError(
"To speed up bins rasterization, the X matrix in the table, when sparse, should be a csc_matrix matrix. "
Expand Down Expand Up @@ -162,7 +162,7 @@ def channel_rasterization(block_id: tuple[int, int, int] | None) -> ArrayLike:
sub_x = sub_df.geometry.x.values
sub_y = sub_df.geometry.y.values
else:
assert isinstance(sub_df.iloc[0].geometry, (Polygon, MultiPolygon))
assert isinstance(sub_df.iloc[0].geometry, Polygon | MultiPolygon)
sub_x = sub_df.centroid.x
sub_y = sub_df.centroid.y
else:
Expand Down
12 changes: 7 additions & 5 deletions src/spatialdata/_core/operations/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
import numpy as np
from dask.array.core import Array as DaskArray
from dask.dataframe import DataFrame as DaskDataFrame
from datatree import DataTree
from geopandas import GeoDataFrame
from shapely import Point
from xarray import DataArray
from xarray import DataArray, Dataset, DataTree

from spatialdata._core.spatialdata import SpatialData
from spatialdata._types import ArrayLike
Expand Down Expand Up @@ -161,13 +160,13 @@ def _set_transformation_for_transformed_elements(
assert to_coordinate_system is None

to_prepend: BaseTransformation | None
if isinstance(element, (DataArray, DataTree)):
if isinstance(element, DataArray | DataTree):
if maintain_positioning:
assert raster_translation is not None
to_prepend = Sequence([raster_translation, transformation.inverse()])
else:
to_prepend = raster_translation
elif isinstance(element, (GeoDataFrame, DaskDataFrame)):
elif isinstance(element, GeoDataFrame | DaskDataFrame):
assert raster_translation is None
to_prepend = transformation.inverse() if maintain_positioning else Identity()
else:
Expand Down Expand Up @@ -393,8 +392,11 @@ def _(
raster_translation = raster_translation_single_scale
# we set a dummy empty dict for the transformation that will be replaced with the correct transformation for
# each scale later in this function, when calling set_transformation()
transformed_dict[k] = DataArray(transformed_dask, dims=xdata.dims, name=xdata.name, attrs={TRANSFORM_KEY: {}})
transformed_dict[k] = Dataset(
{"image": DataArray(transformed_dask, dims=xdata.dims, name=xdata.name, attrs={TRANSFORM_KEY: {}})}
)
if channel_names is not None:
# This expression returns a dataset now.
transformed_dict[k] = transformed_dict[k].assign_coords(c=channel_names)

# mypy thinks that schema could be ShapesModel, PointsModel, ...
Expand Down
Loading

0 comments on commit dc4f508

Please sign in to comment.