diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index 62536ce38..df64b6f65 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -4,8 +4,9 @@ import numpy as np -from typing import TYPE_CHECKING, Optional, Hashable, Literal +from typing import TYPE_CHECKING, Callable, Optional, Hashable, Literal +from uxarray.constants import GRID_DIMS from uxarray.formatting_html import array_repr from html import escape @@ -971,8 +972,6 @@ def isel(self, ignore_grid=False, *args, **kwargs): > uxda.subset(n_node=[1, 2, 3]) """ - from uxarray.constants import GRID_DIMS - if any(grid_dim in kwargs for grid_dim in GRID_DIMS) and not ignore_grid: # slicing a grid-dimension through Grid object @@ -1029,3 +1028,117 @@ def _slice_from_grid(self, sliced_grid): dims=self.dims, attrs=self.attrs, ) + + def neighborhood_filter( + self, + func: Callable = np.mean, + r: float = 1.0, + ) -> UxDataArray: + """Apply neighborhood filter + Parameters: + ----------- + func: Callable, default=np.mean + Apply this function to neighborhood + r : float, default=1. + Radius of neighborhood. For spherical coordinates, the radius is in units of degrees, + and for cartesian coordinates, the radius is in meters. + Returns: + -------- + destination_data : np.ndarray + Filtered data. + """ + + if self._face_centered(): + data_mapping = "face centers" + elif self._node_centered(): + data_mapping = "nodes" + elif self._edge_centered(): + data_mapping = "edge centers" + else: + raise ValueError( + "Data_mapping is not face, node, or edge. Could not define data_mapping." + ) + + # reconstruct because the cached tree could be built from + # face centers, edge centers or nodes. + tree = self.uxgrid.get_ball_tree(coordinates=data_mapping, reconstruct=True) + + coordinate_system = tree.coordinate_system + + if coordinate_system == "spherical": + if data_mapping == "nodes": + lon, lat = ( + self.uxgrid.node_lon.values, + self.uxgrid.node_lat.values, + ) + elif data_mapping == "face centers": + lon, lat = ( + self.uxgrid.face_lon.values, + self.uxgrid.face_lat.values, + ) + elif data_mapping == "edge centers": + lon, lat = ( + self.uxgrid.edge_lon.values, + self.uxgrid.edge_lat.values, + ) + else: + raise ValueError( + f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', " + f"but received: {data_mapping}" + ) + + dest_coords = np.vstack((lon, lat)).T + + elif coordinate_system == "cartesian": + if data_mapping == "nodes": + x, y, z = ( + self.uxgrid.node_x.values, + self.uxgrid.node_y.values, + self.uxgrid.node_z.values, + ) + elif data_mapping == "face centers": + x, y, z = ( + self.uxgrid.face_x.values, + self.uxgrid.face_y.values, + self.uxgrid.face_z.values, + ) + elif data_mapping == "edge centers": + x, y, z = ( + self.uxgrid.edge_x.values, + self.uxgrid.edge_y.values, + self.uxgrid.edge_z.values, + ) + else: + raise ValueError( + f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', " + f"but received: {data_mapping}" + ) + + dest_coords = np.vstack((x, y, z)).T + + else: + raise ValueError( + f"Invalid coordinate_system. Expected either 'spherical' or 'cartesian', but received {coordinate_system}" + ) + + neighbor_indices = tree.query_radius(dest_coords, r=r) + + # Construct numpy array for filtered variable. + destination_data = np.empty(self.data.shape) + + # Assert last dimension is a GRID dimension. + assert self.dims[-1] in GRID_DIMS, ( + f"expected last dimension of uxDataArray {self.data.dims[-1]} " + f"to be one of {GRID_DIMS}" + ) + # Apply function to indices on last axis. + for i, idx in enumerate(neighbor_indices): + if len(idx): + destination_data[..., i] = func(self.data[..., idx]) + + # Construct UxDataArray for filtered variable. + uxda_filter = self._copy() + + uxda_filter.data = destination_data + + return uxda_filter diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index 23843f9ee..21fc8f76a 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -5,8 +5,9 @@ import sys -from typing import Optional, IO +from typing import Callable, Optional, IO +from uxarray.constants import GRID_DIMS from uxarray.grid import Grid from uxarray.core.dataarray import UxDataArray @@ -337,3 +338,38 @@ def to_array(self) -> UxDataArray: xarr = super().to_array() return UxDataArray(xarr, uxgrid=self.uxgrid) + + def neighborhood_filter( + self, + func: Callable = np.mean, + r: float = 1.0, + ): + """Neighborhood function implementation for ``UxDataset``. + Parameters + --------- + func : Callable = np.mean + Apply this function to neighborhood + r : float, default=1. + Radius of neighborhood. For spherical coordinates, the radius is in units of degrees, + and for cartesian coordinates, the radius is in meters. + """ + + destination_uxds = self._copy() + # Loop through uxDataArrays in uxDataset + for var_name in self.data_vars: + uxda = self[var_name] + + # Skip if uxDataArray has no GRID dimension. + grid_dims = [dim for dim in uxda.dims if dim in GRID_DIMS] + if len(grid_dims) == 0: + continue + + # Put GRID dimension last for UxDataArray.neighborhood_filter. + remember_dim_order = uxda.dims + uxda = uxda.transpose(..., grid_dims[0]) + # Filter uxDataArray. + uxda = uxda.neighborhood_filter(func, r) + # Restore old dimension order. + destination_uxds[var_name] = uxda.transpose(*remember_dim_order) + + return destination_uxds