Skip to content

Commit

Permalink
Selection for GeoDataFrame engine in plotting routines (#987)
Browse files Browse the repository at this point in the history
* add engine, fix project

* add test case and comments

* initial updates to poly and lc

* fix comment and update tests

* update to_geodataframe docstring

* update edge plot

* add default clabel for edge plot

* update docstrings in geometry functions

* update to_geodataframe docstring to warn about split polygon projections

* remove unused parameter

* update call after removed unused argument

* remove commented out bit
  • Loading branch information
philipc2 authored Oct 10, 2024
1 parent e491f57 commit 691999b
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 56 deletions.
42 changes: 22 additions & 20 deletions test/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import uxarray as ux
import holoviews as hv


from unittest import TestCase
from pathlib import Path
Expand Down Expand Up @@ -44,42 +46,33 @@ def test_face_centered_data(self):
uxds = ux.open_dataset(gridfile_mpas, gridfile_mpas)

for backend in ['matplotlib', 'bokeh']:

uxds['bottomDepth'].plot(backend=backend)

uxds['bottomDepth'].plot.polygons(backend=backend)

uxds['bottomDepth'].plot.points(backend=backend)

uxds['bottomDepth'].plot.rasterize(method='polygon',
backend=backend)
assert(isinstance(uxds['bottomDepth'].plot(backend=backend), hv.DynamicMap))
assert(isinstance(uxds['bottomDepth'].plot.polygons(backend=backend), hv.DynamicMap))
assert(isinstance(uxds['bottomDepth'].plot.points(backend=backend), hv.Points))

def test_face_centered_remapped_dim(self):
"""Tests execution of plotting method on a data variable whose
dimension needed to be re-mapped."""
uxds = ux.open_dataset(gridfile_ne30, datafile_ne30)

for backend in ['matplotlib', 'bokeh']:
assert(isinstance(uxds['psi'].plot(backend=backend), hv.DynamicMap))
assert(isinstance(uxds['psi'].plot.polygons(backend=backend), hv.DynamicMap))
assert(isinstance(uxds['psi'].plot.points(backend=backend), hv.Points))

uxds['psi'].plot(backend=backend)

uxds['psi'].plot.polygons(backend=backend)

uxds['psi'].plot.points(backend=backend)

uxds['psi'].plot.rasterize(method='polygon', backend=backend)

def test_node_centered_data(self):
"""Tests execution of plotting methods on node-centered data."""

uxds = ux.open_dataset(gridfile_geoflow, datafile_geoflow)

for backend in ['matplotlib', 'bokeh']:
uxds['v1'][0][0].plot(backend=backend)
assert(isinstance(uxds['v1'][0][0].plot(backend=backend), hv.Points))

uxds['v1'][0][0].plot.points(backend=backend)
assert(isinstance(uxds['v1'][0][0].plot.points(backend=backend), hv.Points))

assert(isinstance(uxds['v1'][0][0].topological_mean(destination='face').plot.polygons(backend=backend), hv.DynamicMap))

uxds['v1'][0][0].topological_mean(destination='face').plot.polygons(backend=backend)


def test_clabel(self):
Expand All @@ -88,9 +81,18 @@ def test_clabel(self):
uxds = ux.open_dataset(gridfile_geoflow, datafile_geoflow)

raster_no_clabel = uxds['v1'][0][0].plot.rasterize(method='point')

raster_with_clabel = uxds['v1'][0][0].plot.rasterize(method='point', clabel='Foo')

def test_engine(self):
uxds = ux.open_dataset(gridfile_mpas, gridfile_mpas)
_plot_sp = uxds['bottomDepth'].plot.polygons(rasterize=True, engine='spatialpandas')
_plot_gp = uxds['bottomDepth'].plot.polygons(rasterize=True, engine='geopandas')

assert isinstance(_plot_sp, hv.DynamicMap)
assert isinstance(_plot_gp, hv.DynamicMap)



class TestXarrayMethods(TestCase):

def test_dataset(self):
Expand Down
9 changes: 5 additions & 4 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,11 @@ def to_geodataframe(
self,
periodic_elements: Optional[str] = "exclude",
projection: Optional[ccrs.Projection] = None,
project: Optional[bool] = False,
cache: Optional[bool] = True,
override: Optional[bool] = False,
engine: Optional[str] = "spatialpandas",
exclude_antimeridian: Optional[bool] = None,
**kwargs,
):
"""Constructs a ``GeoDataFrame`` consisting of polygons representing
the faces of the current ``Grid`` with a face-centered data variable
Expand All @@ -178,7 +178,8 @@ def to_geodataframe(
- 'split': Periodic elements will be identified and split using the ``antimeridian`` package
- 'ignore': No processing will be applied to periodic elements.
projection: ccrs.Projection, optional
Geographic projection used to transform polygons
Geographic projection used to transform polygons. Only supported when periodic_elements is set to
'ignore' or 'exclude'
cache: bool, optional
Flag used to select whether to cache the computed GeoDataFrame
override: bool, optional
Expand All @@ -191,7 +192,7 @@ def to_geodataframe(
Returns
-------
gdf : spatialpandas.GeoDataFrame
gdf : spatialpandas.GeoDataFrame or geopandas.GeoDataFrame
The output ``GeoDataFrame`` with a filled out "geometry" column of polygons and a data column with the
same name as the ``UxDataArray`` (or named ``var`` if no name exists)
"""
Expand All @@ -207,7 +208,7 @@ def to_geodataframe(
gdf, non_nan_polygon_indices = self.uxgrid.to_geodataframe(
periodic_elements=periodic_elements,
projection=projection,
project=project,
project=kwargs.get("project", True),
cache=cache,
override=override,
exclude_antimeridian=exclude_antimeridian,
Expand Down
35 changes: 20 additions & 15 deletions uxarray/grid/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _correct_central_longitude(node_lon, node_lat, projection):

def _grid_to_polygon_geodataframe(grid, periodic_elements, projection, project, engine):
"""Converts the faces of a ``Grid`` into a ``spatialpandas.GeoDataFrame``
with a geometry column of polygons."""
or ``geopandas.GeoDataFrame`` with a geometry column of polygons."""

node_lon, node_lat, central_longitude = _correct_central_longitude(
grid.node_lon.values, grid.node_lat.values, projection
Expand Down Expand Up @@ -214,9 +214,8 @@ def _grid_to_polygon_geodataframe(grid, periodic_elements, projection, project,
gdf = _build_geodataframe_with_antimeridian(
polygon_shells,
projected_polygon_shells,
projection,
antimeridian_face_indices,
engine=geopandas,
engine=engine,
)
elif periodic_elements == "ignore":
if engine == "geopandas":
Expand Down Expand Up @@ -248,8 +247,9 @@ def _grid_to_polygon_geodataframe(grid, periodic_elements, projection, project,
def _build_geodataframe_without_antimeridian(
polygon_shells, projected_polygon_shells, antimeridian_face_indices, engine
):
"""Builds a ``spatialpandas.GeoDataFrame`` excluding any faces that cross
the antimeridian."""
"""Builds a ``spatialpandas.GeoDataFrame`` or
``geopandas.GeoDataFrame``excluding any faces that cross the
antimeridian."""
if projected_polygon_shells is not None:
# use projected shells if a projection is applied
shells_without_antimeridian = np.delete(
Expand All @@ -276,12 +276,11 @@ def _build_geodataframe_without_antimeridian(
def _build_geodataframe_with_antimeridian(
polygon_shells,
projected_polygon_shells,
projection,
antimeridian_face_indices,
engine,
):
"""Builds a ``spatialpandas.GeoDataFrame`` including any faces that cross
the antimeridian."""
"""Builds a ``spatialpandas.GeoDataFrame`` or ``geopandas.GeoDataFrame``
including any faces that cross the antimeridian."""
polygons = _build_corrected_shapely_polygons(
polygon_shells, projected_polygon_shells, antimeridian_face_indices
)
Expand Down Expand Up @@ -425,7 +424,8 @@ def _grid_to_matplotlib_polycollection(
# Handle unsupported configuration: splitting periodic elements with projection
if periodic_elements == "split" and projection is not None:
raise ValueError(
"Projections are not supported when splitting periodic elements.'"
"Explicitly projecting lines is not supported. Please pass in your projection"
"using the 'transform' parameter"
)

# Correct the central longitude and build polygon shells
Expand Down Expand Up @@ -533,7 +533,7 @@ def _grid_to_matplotlib_polycollection(
return PolyCollection(polygon_shells, **kwargs), []


def _get_polygons(grid, periodic_elements, projection=None):
def _get_polygons(grid, periodic_elements, projection=None, apply_projection=True):
# Correct the central longitude if projection is provided
node_lon, node_lat, central_longitude = _correct_central_longitude(
grid.node_lon.values, grid.node_lat.values, projection
Expand All @@ -552,7 +552,7 @@ def _get_polygons(grid, periodic_elements, projection=None):
)

# If projection is provided, create the projected polygon shells
if projection:
if projection and apply_projection:
projected_polygon_shells = _build_polygon_shells(
node_lon,
node_lat,
Expand Down Expand Up @@ -625,8 +625,14 @@ def _grid_to_matplotlib_linecollection(
):
"""Constructs and returns a ``matplotlib.collections.LineCollection``"""

if periodic_elements == "split" and projection is not None:
apply_projection = False
else:
apply_projection = True

# do not explicitly project when splitting elements
polygons, central_longitude, _, _ = _get_polygons(
grid, periodic_elements, projection
grid, periodic_elements, projection, apply_projection
)

# Convert polygons to line segments for the LineCollection
Expand All @@ -639,14 +645,13 @@ def _grid_to_matplotlib_linecollection(
else:
lines.append(np.array(boundary.coords))

# Set default transform if not provided
if "transform" not in kwargs:
if projection is None:
# Set default transform if one is not provided not provided
if projection is None or not apply_projection:
kwargs["transform"] = ccrs.PlateCarree(central_longitude=central_longitude)
else:
kwargs["transform"] = projection

# Return a LineCollection of the line segments
return LineCollection(lines, **kwargs)


Expand Down
17 changes: 7 additions & 10 deletions uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1635,13 +1635,13 @@ def to_geodataframe(
self,
periodic_elements: Optional[str] = "exclude",
projection: Optional[ccrs.Projection] = None,
project: Optional[bool] = False,
cache: Optional[bool] = True,
override: Optional[bool] = False,
engine: Optional[str] = "spatialpandas",
exclude_antimeridian: Optional[bool] = None,
return_non_nan_polygon_indices: Optional[bool] = False,
exclude_nan_polygons: Optional[bool] = True,
**kwargs,
):
"""Constructs a ``GeoDataFrame`` consisting of polygons representing
the faces of the current ``Grid``
Expand All @@ -1661,7 +1661,8 @@ def to_geodataframe(
- 'split': Periodic elements will be identified and split using the ``antimeridian`` package
- 'ignore': No processing will be applied to periodic elements.
projection: ccrs.Projection, optional
Geographic projection used to transform polygons
Geographic projection used to transform polygons. Only supported when periodic_elements is set to
'ignore' or 'exclude'
cache: bool, optional
Flag used to select whether to cache the computed GeoDataFrame
override: bool, optional
Expand All @@ -1679,7 +1680,7 @@ def to_geodataframe(
Returns
-------
gdf : spatialpandas.GeoDataFrame
gdf : spatialpandas.GeoDataFrame or geopandas.GeoDataFrame
The output ``GeoDataFrame`` with a filled out "geometry" column of polygons.
"""

Expand All @@ -1688,6 +1689,9 @@ def to_geodataframe(
f"Invalid engine. Expected one of ['spatialpandas', 'geopandas'] but received {engine}"
)

# if project is false, projection is only used for determining central coordinates
project = kwargs.get("project", True)

if projection and project:
if periodic_elements == "split":
raise ValueError(
Expand Down Expand Up @@ -1871,13 +1875,6 @@ def to_linecollection(
f"Invalid value for 'periodic_elements'. Expected one of ['ignore', 'exclude', 'split'] but received: {periodic_elements}"
)

if projection is not None:
if periodic_elements == "split":
raise ValueError(
"Setting ``periodic_elements='split'`` is not supported when a "
"projection is provided."
)

if self._line_collection_cached_parameters["line_collection"] is not None:
if (
self._line_collection_cached_parameters["periodic_elements"]
Expand Down
36 changes: 29 additions & 7 deletions uxarray/plot/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,13 @@ def face_centers(self, backend=None, **kwargs):

face_centers.__doc__ = face_coords.__doc__

def edges(self, periodic_elements="exclude", backend=None, **kwargs):
def edges(
self,
periodic_elements="exclude",
backend=None,
engine="spatialpandas",
**kwargs,
):
"""Plots the edges of a Grid.
This function plots the edges of the grid as geographical paths using `hvplot`.
Expand All @@ -182,6 +188,8 @@ def edges(self, periodic_elements="exclude", backend=None, **kwargs):
- "split": Split periodic elements.
backend : str or None, optional
Plotting backend to use. One of ['matplotlib', 'bokeh']. Equivalent to running holoviews.extension(backend)
engine: str, optional
Engine to use for GeoDataFrame construction. One of ['spatialpandas', 'geopandas']
**kwargs : dict
Additional keyword arguments passed to `hvplot.paths`. These can include:
- "rasterize" (bool): Whether to rasterize the plot (default: False),
Expand All @@ -195,7 +203,6 @@ def edges(self, periodic_elements="exclude", backend=None, **kwargs):
gdf.hvplot.paths : hvplot.paths
A paths plot of the edges of the unstructured grid
"""

uxarray.plot.utils.backend.assign(backend)

if "rasterize" not in kwargs:
Expand All @@ -212,8 +219,11 @@ def edges(self, periodic_elements="exclude", backend=None, **kwargs):
kwargs["crs"] = ccrs.PlateCarree(central_longitude=central_longitude)

gdf = self._uxgrid.to_geodataframe(
periodic_elements=periodic_elements, projection=kwargs.get("projection")
)[["geometry"]]
periodic_elements=periodic_elements,
projection=kwargs.get("projection"),
engine=engine,
project=False,
)

return gdf.hvplot.paths(geo=True, **kwargs)

Expand Down Expand Up @@ -260,8 +270,15 @@ def __getattr__(self, name: str) -> Any:
else:
raise AttributeError(f"Unsupported Plotting Method: '{name}'")

def polygons(self, periodic_elements="exclude", backend=None, *args, **kwargs):
"""Generate a shaded polygon plot of a face-centered data variable.
def polygons(
self,
periodic_elements="exclude",
backend=None,
engine="spatialpandas",
*args,
**kwargs,
):
"""Generated a shaded polygon plot.
This function plots the faces of an unstructured grid shaded with a face-centered data variable using hvplot.
It allows for rasterization, projection settings, and labeling of the data variable to be
Expand All @@ -278,6 +295,8 @@ def polygons(self, periodic_elements="exclude", backend=None, *args, **kwargs):
- "ignore": Include periodic elements without any corrections
backend : str or None, optional
Plotting backend to use. One of ['matplotlib', 'bokeh']. Equivalent to running holoviews.extension(backend)
engine: str, optional
Engine to use for GeoDataFrame construction. One of ['spatialpandas', 'geopandas']
*args : tuple
Additional positional arguments to be passed to `hvplot.polygons`.
**kwargs : dict
Expand Down Expand Up @@ -309,7 +328,10 @@ def polygons(self, periodic_elements="exclude", backend=None, *args, **kwargs):
kwargs["crs"] = ccrs.PlateCarree(central_longitude=central_longitude)

gdf = self._uxda.to_geodataframe(
periodic_elements=periodic_elements, projection=kwargs.get("projection")
periodic_elements=periodic_elements,
projection=kwargs.get("projection"),
engine=engine,
project=False,
)

return gdf.hvplot.polygons(
Expand Down

0 comments on commit 691999b

Please sign in to comment.