-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Apply function to points within circular neighborhood #941
base: main
Are you sure you want to change the base?
Changes from all commits
75d832e
b850bc4
8ec0193
5605949
70ba961
0c7bc1e
d6d8a33
47b9cda
f75db0d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,8 +4,9 @@ | |
import numpy as np | ||
|
||
|
||
from typing import TYPE_CHECKING, Optional, Hashable, Literal | ||
from typing import TYPE_CHECKING, Callable, Optional, Hashable, Literal | ||
|
||
from uxarray.constants import GRID_DIMS | ||
from uxarray.formatting_html import array_repr | ||
|
||
from html import escape | ||
|
@@ -971,8 +972,6 @@ def isel(self, ignore_grid=False, *args, **kwargs): | |
> uxda.subset(n_node=[1, 2, 3]) | ||
""" | ||
|
||
from uxarray.constants import GRID_DIMS | ||
|
||
if any(grid_dim in kwargs for grid_dim in GRID_DIMS) and not ignore_grid: | ||
# slicing a grid-dimension through Grid object | ||
|
||
|
@@ -1029,3 +1028,117 @@ def _slice_from_grid(self, sliced_grid): | |
dims=self.dims, | ||
attrs=self.attrs, | ||
) | ||
|
||
def neighborhood_filter( | ||
self, | ||
func: Callable = np.mean, | ||
r: float = 1.0, | ||
) -> UxDataArray: | ||
"""Apply neighborhood filter | ||
Parameters: | ||
----------- | ||
func: Callable, default=np.mean | ||
Apply this function to neighborhood | ||
r : float, default=1. | ||
Radius of neighborhood. For spherical coordinates, the radius is in units of degrees, | ||
and for cartesian coordinates, the radius is in meters. | ||
Returns: | ||
-------- | ||
destination_data : np.ndarray | ||
Filtered data. | ||
""" | ||
|
||
if self._face_centered(): | ||
data_mapping = "face centers" | ||
elif self._node_centered(): | ||
data_mapping = "nodes" | ||
elif self._edge_centered(): | ||
data_mapping = "edge centers" | ||
else: | ||
raise ValueError( | ||
"Data_mapping is not face, node, or edge. Could not define data_mapping." | ||
) | ||
|
||
# reconstruct because the cached tree could be built from | ||
# face centers, edge centers or nodes. | ||
tree = self.uxgrid.get_ball_tree(coordinates=data_mapping, reconstruct=True) | ||
Comment on lines
+1063
to
+1064
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably fix this logic in if self._ball_tree is None or reconstruct:
self._ball_tree = BallTree(
self,
coordinates=coordinates,
distance_metric=distance_metric,
coordinate_system=coordinate_system,
reconstruct=reconstruct,
)
else:
if coordinates != self._ball_tree._coordinates:
self._ball_tree.coordinates = coordinates The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense. So, move the coordinates check to the if-clause like this? if (
self._ball_tree is None
or coordinates != self._ball_tree._coordinates
or reconstruct
):
self._ball_tree = BallTree(
self,
coordinates=coordinates,
distance_metric=distance_metric,
coordinate_system=coordinate_system,
reconstruct=reconstruct,
) What if the coordinate_system is different? Would that also require a newly constructed tree? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whatever logic is fixed in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. checking coordinate system also (coordinate_system is not a hidden variable of _ball_tree; it has no underscore): if (
self._ball_tree is None
or coordinates != self._ball_tree._coordinates
or coordinate_system != self._ball_tree.coordinate_system
or reconstruct
):
self._ball_tree = BallTree(
self,
coordinates=coordinates,
distance_metric=distance_metric,
coordinate_system=coordinate_system,
reconstruct=reconstruct,
) |
||
|
||
coordinate_system = tree.coordinate_system | ||
|
||
if coordinate_system == "spherical": | ||
if data_mapping == "nodes": | ||
lon, lat = ( | ||
self.uxgrid.node_lon.values, | ||
self.uxgrid.node_lat.values, | ||
) | ||
elif data_mapping == "face centers": | ||
lon, lat = ( | ||
self.uxgrid.face_lon.values, | ||
self.uxgrid.face_lat.values, | ||
) | ||
elif data_mapping == "edge centers": | ||
lon, lat = ( | ||
self.uxgrid.edge_lon.values, | ||
self.uxgrid.edge_lat.values, | ||
) | ||
else: | ||
raise ValueError( | ||
f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', " | ||
f"but received: {data_mapping}" | ||
) | ||
|
||
dest_coords = np.vstack((lon, lat)).T | ||
|
||
elif coordinate_system == "cartesian": | ||
if data_mapping == "nodes": | ||
x, y, z = ( | ||
self.uxgrid.node_x.values, | ||
self.uxgrid.node_y.values, | ||
self.uxgrid.node_z.values, | ||
) | ||
elif data_mapping == "face centers": | ||
x, y, z = ( | ||
self.uxgrid.face_x.values, | ||
self.uxgrid.face_y.values, | ||
self.uxgrid.face_z.values, | ||
) | ||
elif data_mapping == "edge centers": | ||
x, y, z = ( | ||
self.uxgrid.edge_x.values, | ||
self.uxgrid.edge_y.values, | ||
self.uxgrid.edge_z.values, | ||
) | ||
else: | ||
raise ValueError( | ||
f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', " | ||
f"but received: {data_mapping}" | ||
) | ||
Comment on lines
+1065
to
+1115
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use #974 's new |
||
|
||
dest_coords = np.vstack((x, y, z)).T | ||
|
||
else: | ||
raise ValueError( | ||
f"Invalid coordinate_system. Expected either 'spherical' or 'cartesian', but received {coordinate_system}" | ||
) | ||
|
||
neighbor_indices = tree.query_radius(dest_coords, r=r) | ||
|
||
# Construct numpy array for filtered variable. | ||
destination_data = np.empty(self.data.shape) | ||
|
||
# Assert last dimension is a GRID dimension. | ||
assert self.dims[-1] in GRID_DIMS, ( | ||
f"expected last dimension of uxDataArray {self.data.dims[-1]} " | ||
f"to be one of {GRID_DIMS}" | ||
) | ||
# Apply function to indices on last axis. | ||
for i, idx in enumerate(neighbor_indices): | ||
if len(idx): | ||
destination_data[..., i] = func(self.data[..., idx]) | ||
|
||
# Construct UxDataArray for filtered variable. | ||
uxda_filter = self._copy() | ||
|
||
uxda_filter.data = destination_data | ||
|
||
return uxda_filter |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This implementation looks great! May we move the bulk of the logic into the
uxarray.grid.neighbors
module and call that helper from here?We can keep the data-mapping checks here, and anything related to constructing and returining the final data array but the bulk of the computations would go inside a helper in the module mentioned above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have to think about how to do that, but I am happy to defer to you.