Skip to content

Commit

Permalink
fix models docstrings (#530)
Browse files Browse the repository at this point in the history
* fix models docstrings

* added force_2d() utils
  • Loading branch information
LucaMarconato authored Mar 30, 2024
1 parent 3b29f0b commit e9f57a1
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 13 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [0.1.2] - 2024-xx-xx
## [0.1.3] - 2024-xx-xx

## [0.1.2] - 2024-03-30

### Minor

- Made `get_channels()` public.
- Added utils `force_2d()` to force 3D shapes to 2D (this is a temporary solution until `.force_2d()` is available in `geopandas`).

## [0.1.1] - 2024-03-28

Expand Down
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ The elements (building-blocks) that consitute `SpatialData`.
points_geopandas_to_dask_dataframe
points_dask_dataframe_to_geopandas
get_channels
force_2d
```

## Transformations
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
X,
Y,
Z,
force_2d,
get_axes_names,
get_channels,
get_spatial_axes,
Expand Down Expand Up @@ -50,4 +51,5 @@
"check_target_region_column_symmetry",
"get_table_keys",
"get_channels",
"force_2d",
]
39 changes: 39 additions & 0 deletions src/spatialdata/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

import dask.dataframe as dd
import geopandas
import numpy as np
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from multiscale_spatial_image import MultiscaleSpatialImage
from shapely.geometry import MultiPolygon, Point, Polygon
from spatial_image import SpatialImage

from spatialdata._logging import logger
Expand Down Expand Up @@ -296,3 +298,40 @@ def _(data: MultiscaleSpatialImage) -> list[Any]:
if len(channels) > 1:
raise ValueError("TODO")
return list(next(iter(channels)))


def force_2d(gdf: GeoDataFrame) -> None:
"""
Force the geometries of a shapes object GeoDataFrame to be 2D by modifying the geometries in place.
Geopandas introduced a method called `force_2d()` to drop the z dimension.
Unfortunately, this feature, as of geopandas == 0.14.3, is still not released.
Similarly, the recently released shapely >= 2.0.3 implemented `force_2d()`, but currently there are installation
errors.
A similar function has been developed in When `.force_2d()`
Parameters
----------
gdf
GeoDataFrame with 2D or 3D geometries
"""
new_shapes = []
any_3d = False
for shape in gdf.geometry:
if shape.has_z:
any_3d = True
if isinstance(shape, Point):
new_shape = Point(shape.x, shape.y)
elif isinstance(shape, Polygon):
new_shape = Polygon(np.array(shape.exterior.coords.xy).T)
elif isinstance(shape, MultiPolygon):
new_shape = MultiPolygon([Polygon(np.array(p.exterior.coords.xy).T) for p in shape.geoms])
else:
raise ValueError(f"Unsupported geometry type: {type(shape)}")
new_shapes.append(new_shape)
else:
new_shapes.append(shape)
if any_3d:
gdf.geometry = new_shapes
20 changes: 13 additions & 7 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,9 @@ def validate(cls, data: GeoDataFrame) -> None:
if n != 2:
warnings.warn(
f"The geometry column of the GeoDataFrame has {n} dimensions, while 2 is expected. Please consider "
"discarding the third dimension as it could led to unexpected behaviors.",
"discarding the third dimension as it could led to unexpected behaviors. To achieve so, you can use"
" `.force_2d()` if you are using `geopandas > 0.14.3, otherwise you can use `force_2d()` from "
"`spatialdata.models`.",
UserWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -512,24 +514,28 @@ def parse(cls, data: Any, **kwargs: Any) -> DaskDataFrame:
Data to parse:
- If :class:`numpy.ndarray`, an `annotation` :class:`pandas.DataFrame`
must be provided, as well as the `feature_key` in the `annotation`. Furthermore,
can be provided, as well as a `feature_key` column in the `annotation` dataframe. Furthermore,
:class:`numpy.ndarray` is assumed to have shape `(n_points, axes)`, with `axes` being
"x", "y" and optionally "z".
- If :class:`pandas.DataFrame`, a `coordinates` mapping must be provided
with key as *valid axes* and value as column names in dataframe.
- If :class:`pandas.DataFrame`, a `coordinates` mapping can be provided
with key as *valid axes* ('x', 'y', 'z') and value as column names in dataframe. If the dataframe
already has columns named 'x', 'y' and 'z', the mapping can be omitted.
annotation
Annotation dataframe. Only if `data` is :class:`numpy.ndarray`. If data is an array, the index of the
annotations will be used as the index of the parsed points.
coordinates
Mapping of axes names (keys) to column names (valus) in `data`. Only if `data` is
:class:`pandas.DataFrame`. Example: {'x': 'my_x_column', 'y': 'my_y_column'}.
If not provided and `data` is :class:`pandas.DataFrame`, and `x`, `y` and optinally `z` are column names,
If not provided and `data` is :class:`pandas.DataFrame`, and `x`, `y` and optionally `z` are column names,
then they will be used as coordinates.
feature_key
Feature key in `annotation` or `data`.
Optional, feature key in `annotation` or `data`. Example use case: gene id categorical column describing the
gene identity of each point.
instance_key
Instance key in `annotation` or `data`.
Optional, instance key in `annotation` or `data`. Example use case: cell id column, describing which cell
a point belongs to. This argument is likely going to be deprecated:
https://github.com/scverse/spatialdata/issues/503.
transformations
Transformations of points.
kwargs
Expand Down
47 changes: 42 additions & 5 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,30 @@
from multiscale_spatial_image import MultiscaleSpatialImage
from numpy.random import default_rng
from pandas.api.types import is_categorical_dtype
from shapely.geometry import MultiPolygon, Point, Polygon
from shapely.io import to_ragged_array
from spatial_image import SpatialImage, to_spatial_image
from spatialdata._core.spatialdata import SpatialData
from spatialdata._types import ArrayLike
from spatialdata.models import (
from spatialdata.models._utils import (
force_2d,
points_dask_dataframe_to_geopandas,
points_geopandas_to_dask_dataframe,
validate_axis_name,
)
from spatialdata.models.models import (
Image2DModel,
Image3DModel,
Labels2DModel,
Labels3DModel,
PointsModel,
RasterSchema,
ShapesModel,
TableModel,
get_axes_names,
get_model,
points_dask_dataframe_to_geopandas,
points_geopandas_to_dask_dataframe,
)
from spatialdata.models._utils import validate_axis_name
from spatialdata.models.models import RasterSchema
from spatialdata.testing import assert_elements_are_identical
from spatialdata.transformations._utils import (
_set_transformations,
_set_transformations_xarray,
Expand Down Expand Up @@ -427,3 +432,35 @@ def test_model_polygon_z():
match="The geometry column of the GeoDataFrame has 3 dimensions, while 2 is expected. Please consider",
):
_ = ShapesModel.parse(gpd.GeoDataFrame(geometry=[polygon]))


def test_force2d():
# let's create a shapes object (circles) constructed from 3D points (let's mix 2D and 3D)
circles_3d = ShapesModel.parse(GeoDataFrame({"geometry": (Point(1, 1, 1), Point(2, 2)), "radius": [2, 2]}))

polygon1 = Polygon([(0, 0, 0), (1, 0, 0), (1, 1, 0)])
polygon2 = Polygon([(0, 0), (1, 0), (1, 1)])

# let's create a shapes object (polygons) constructed from 3D polygons
polygons_3d = ShapesModel.parse(GeoDataFrame({"geometry": [polygon1, polygon2]}))

# let's create a shapes object (multipolygons) constructed from 3D multipolygons
multipolygons_3d = ShapesModel.parse(GeoDataFrame({"geometry": [MultiPolygon([polygon1, polygon2])]}))

force_2d(circles_3d)
force_2d(polygons_3d)
force_2d(multipolygons_3d)

expected_circles_2d = ShapesModel.parse(GeoDataFrame({"geometry": (Point(1, 1), Point(2, 2)), "radius": [2, 2]}))
expected_polygons_2d = ShapesModel.parse(
GeoDataFrame({"geometry": [Polygon([(0, 0), (1, 0), (1, 1)]), Polygon([(0, 0), (1, 0), (1, 1)])]})
)
expected_multipolygons_2d = ShapesModel.parse(
GeoDataFrame(
{"geometry": [MultiPolygon([Polygon([(0, 0), (1, 0), (1, 1)]), Polygon([(0, 0), (1, 0), (1, 1)])])]}
)
)

assert_elements_are_identical(circles_3d, expected_circles_2d)
assert_elements_are_identical(polygons_3d, expected_polygons_2d)
assert_elements_are_identical(multipolygons_3d, expected_multipolygons_2d)

0 comments on commit e9f57a1

Please sign in to comment.