Skip to content

Commit

Permalink
feat(IDW): Add search radius parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
msorvoja authored and nmaarnio committed Dec 4, 2024
1 parent 5ca25b3 commit 890fcf2
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 109 deletions.
9 changes: 8 additions & 1 deletion eis_toolkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,6 +1960,7 @@ def idw_interpolation_cli(
pixel_size: float = None,
extent: Tuple[float, float, float, float] = (None, None, None, None),
power: float = 2.0,
search_radius: Optional[float] = None,
):
"""Apply inverse distance weighting (IDW) interpolation to input vector file."""
from eis_toolkit.exceptions import InvalidParameterValueException
Expand All @@ -1985,7 +1986,13 @@ def idw_interpolation_cli(
with rasterio.open(base_raster) as raster:
profile = raster.profile.copy()

out_image = idw(geodataframe=geodataframe, target_column=target_column, raster_profile=profile, power=power)
out_image = idw(
geodataframe=geodataframe,
target_column=target_column,
raster_profile=profile,
power=power,
search_radius=search_radius,
)
typer.echo("Progress: 75%")

profile["count"] = 1
Expand Down
45 changes: 35 additions & 10 deletions eis_toolkit/vector_processing/idw_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import geopandas as gpd
import numpy as np
from beartype import beartype
from beartype.typing import Union
from beartype.typing import Optional, Union
from rasterio import profiles, transform

from eis_toolkit.exceptions import EmptyDataFrameException, InvalidParameterValueException, NonMatchingCrsException
Expand All @@ -18,6 +18,7 @@ def _idw_interpolation(
raster_height: int,
raster_transform: transform.Affine,
power: Number,
search_radius: Optional[Number],
) -> np.ndarray:

points = np.array(geodataframe.geometry.apply(lambda geom: (geom.x, geom.y)).tolist())
Expand All @@ -34,26 +35,47 @@ def _idw_interpolation(
y = np.linspace(grid_y_min, grid_y_max, raster_height)
y = y[::-1].reshape(-1, 1)

interpolated_values = _idw_core(points[:, 0], points[:, 1], values, x, y, power)
interpolated_values = _idw_core(points[:, 0], points[:, 1], values, x, y, power, search_radius)
interpolated_values = interpolated_values.reshape(raster_height, raster_width)

return interpolated_values


# Distance calculations
def _idw_core(x, y, z, xi, yi: np.ndarray, power: Number) -> np.ndarray:
def _idw_core(
x: np.ndarray,
y: np.ndarray,
z: np.ndarray,
xi: np.ndarray,
yi: np.ndarray,
power: Number,
search_radius: Optional[Number],
) -> np.ndarray:
over = np.zeros((len(yi), len(xi)))
under = np.zeros((len(yi), len(xi)))
for n in range(len(x)):
dist = np.hypot(xi - x[n], yi - y[n])
# Add a small epsilon to avoid division by zero
dist = np.where(dist == 0, 1e-12, dist)
dist = dist**power

over += z[n] / dist
under += 1.0 / dist
# Exclude points outside search radius
if search_radius is not None:
mask = dist <= search_radius
if not np.any(mask):
continue

# Add a small epsilon to avoid division by zero
dist = np.where(dist[mask] == 0, 1e-12, dist[mask]) ** power

over[mask] += z[n] / dist
under[mask] += 1.0 / dist

else:
# Add a small epsilon to avoid division by zero
dist = np.where(dist == 0, 1e-12, dist) ** power

over += z[n] / dist
under += 1.0 / dist

interpolated_values = over / under
interpolated_values = np.divide(over, under, out=np.full_like(over, np.nan), where=under != 0)
return interpolated_values


Expand All @@ -63,6 +85,7 @@ def idw(
target_column: str,
raster_profile: Union[profiles.Profile, dict],
power: Number = 2,
search_radius: Optional[Number] = None,
) -> np.ndarray:
"""Calculate inverse distance weighted (IDW) interpolation.
Expand All @@ -73,6 +96,8 @@ def idw(
crs, transform, width and height.
power: The value for determining the rate at which the weights decrease. As power increases,
the weights for distant points decrease rapidly. Defaults to 2.
search_radius: The search radius within which to consider points for interpolation.
If None, all points are used.
Returns:
Numpy array containing the interpolated values.
Expand All @@ -97,7 +122,7 @@ def idw(
raster_transform = raster_profile.get("transform")

interpolated_values = _idw_interpolation(
geodataframe, target_column, raster_width, raster_height, raster_transform, power
geodataframe, target_column, raster_width, raster_height, raster_transform, power, search_radius
)

return interpolated_values
179 changes: 81 additions & 98 deletions notebooks/testing_idw.ipynb

Large diffs are not rendered by default.

Binary file not shown.
Binary file not shown.
Empty file.
19 changes: 19 additions & 0 deletions tests/vector_processing/idw_interpolation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

test_dir = Path(__file__).parent.parent
idw_test_data = test_dir.joinpath("data/remote/interpolating/idw_test_data.tif")
idw_radius_test_data = test_dir.joinpath("data/remote/interpolating/idw_radius_test_data.tif")


@pytest.fixture
Expand Down Expand Up @@ -74,6 +75,24 @@ def test_validated_points_with_extent(validated_points, raster_profile):
np.testing.assert_almost_equal(interpolated_values, external_values, decimal=2)


def test_validated_points_with_radius(validated_points, raster_profile):
"""Test IDW with search radius."""
target_column = "random_number"
interpolated_values = idw(
geodataframe=validated_points,
target_column=target_column,
raster_profile=raster_profile,
power=2,
search_radius=0.5,
)
assert target_column in validated_points.columns

with rasterio.open(idw_radius_test_data) as src:
external_values = src.read(1)

np.testing.assert_almost_equal(interpolated_values, external_values, decimal=2)


def test_invalid_column(test_points, raster_profile):
"""Test invalid column GeoDataFrame."""
target_column = "not-in-data-column"
Expand Down

0 comments on commit 890fcf2

Please sign in to comment.