Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace stacking gradient search with resample_blocks variant #626

Merged
merged 6 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 35 additions & 276 deletions pyresample/gradient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,256 +53,26 @@ def GradientSearchResampler(source_geo_def, target_geo_def):

def create_gradient_search_resampler(source_geo_def, target_geo_def):
    """Create a gradient search resampler.

    Args:
        source_geo_def: Geometry definition of the data to resample
            (AreaDefinition or SwathDefinition).
        target_geo_def: Geometry definition of the resampling target
            (AreaDefinition or SwathDefinition).

    Returns:
        A ResampleBlocksGradientSearchResampler instance for the supported
        geometry combinations (area->area, swath->area, area->swath).

    Raises:
        NotImplementedError: for unsupported combinations (e.g. swath->swath).
    """
    # NOTE: this block originally mixed the pre- and post-refactor diff lines;
    # the merged version below dispatches purely on the is_*_to_* helpers.
    if (is_area_to_area(source_geo_def, target_geo_def) or
            is_swath_to_area(source_geo_def, target_geo_def) or
            is_area_to_swath(source_geo_def, target_geo_def)):
        return ResampleBlocksGradientSearchResampler(source_geo_def, target_geo_def)
    raise NotImplementedError


@da.as_gufunc(signature='(),()->(),()')
def transform(x_coords, y_coords, src_prj=None, dst_prj=None):
    """Calculate projection coordinates.

    Transforms *x_coords*/*y_coords* from the *src_prj* CRS to the
    *dst_prj* CRS, element-wise (gufunc signature maps scalars to scalars).
    """
    return pyproj.Transformer.from_crs(src_prj, dst_prj).transform(x_coords, y_coords)
def is_area_to_area(source_geo_def, target_geo_def):
    """Check if source is area and target is area."""
    if not isinstance(source_geo_def, AreaDefinition):
        return False
    return isinstance(target_geo_def, AreaDefinition)


class StackingGradientSearchResampler(BaseResampler):
    """Resample using gradient search based bilinear interpolation, using stacking for dask processing."""

    def __init__(self, source_geo_def, target_geo_def):
        """Init GradientResampler."""
        super().__init__(source_geo_def, target_geo_def)
        import warnings
        warnings.warn("You are using the Gradient Search Resampler, which is still EXPERIMENTAL.", stacklevel=2)
        # All of the state below is lazily filled in by compute() and its helpers.
        self.use_input_coords = None  # True: work in source projection; False: target projection
        self._src_dst_filtered = False  # set once non-overlapping chunks have been dropped
        self.prj = None  # pyproj.Proj of the working projection
        self.src_x = None
        self.src_y = None
        self.src_slices = None
        self.dst_x = None
        self.dst_y = None
        self.dst_slices = None
        self.src_gradient_xl = None
        self.src_gradient_xp = None
        self.src_gradient_yl = None
        self.src_gradient_yp = None
        self.dst_polys = {}  # cache of target chunk polygons, keyed by chunk index
        self.dst_mosaic_locations = None
        self.coverage_status = None

    def _get_projection_coordinates(self, datachunks):
        """Get projection coordinates."""
        # Only computed once; use_input_coords is None until the first call.
        if self.use_input_coords is None:
            try:
                # AreaDefinition source: use its projected coordinates directly.
                self.src_x, self.src_y = self.source_geo_def.get_proj_coords(
                    chunks=datachunks)
                src_crs = self.source_geo_def.crs
                self.use_input_coords = True
            except AttributeError:
                # SwathDefinition source has no get_proj_coords: fall back to lon/lats.
                self.src_x, self.src_y = self.source_geo_def.get_lonlats(
                    chunks=datachunks)
                src_crs = pyproj.CRS.from_string("+proj=longlat")
                self.use_input_coords = False
            try:
                self.dst_x, self.dst_y = self.target_geo_def.get_proj_coords(
                    chunks=CHUNK_SIZE)
                dst_crs = self.target_geo_def.crs
            except AttributeError as err:
                if self.use_input_coords is False:
                    raise NotImplementedError('Cannot resample lon/lat to lon/lat with gradient search.') from err
                self.dst_x, self.dst_y = self.target_geo_def.get_lonlats(
                    chunks=CHUNK_SIZE)
                dst_crs = pyproj.CRS.from_string("+proj=longlat")
            if self.use_input_coords:
                # Work in the source projection: bring the target coords over.
                self.dst_x, self.dst_y = transform(
                    self.dst_x, self.dst_y,
                    src_prj=dst_crs, dst_prj=src_crs)
                self.prj = pyproj.Proj(self.source_geo_def.crs)
            else:
                # Work in the target projection: bring the source coords over.
                self.src_x, self.src_y = transform(
                    self.src_x, self.src_y,
                    src_prj=src_crs, dst_prj=dst_crs)
                self.prj = pyproj.Proj(self.target_geo_def.crs)

    def _get_prj_poly(self, geo_def):
        # Boundary polygon of *geo_def* in the working projection:
        # - None if out of Earth disk
        # - False if SwathDefinition
        if isinstance(geo_def, SwathDefinition):
            return False
        try:
            poly = get_polygon(self.prj, geo_def)
        except (NotImplementedError, ValueError):  # out-of-earth disk area or no valid projected boundary coordinates
            poly = None
        return poly

    def _get_src_poly(self, src_y_start, src_y_end, src_x_start, src_x_end):
        """Get bounding polygon for source chunk."""
        geo_def = self.source_geo_def[src_y_start:src_y_end,
                                      src_x_start:src_x_end]
        return self._get_prj_poly(geo_def)

    def _get_dst_poly(self, idx,
                      dst_x_start, dst_x_end,
                      dst_y_start, dst_y_end):
        """Get target chunk polygon."""
        # Memoized per chunk index so each target chunk polygon is built once.
        dst_poly = self.dst_polys.get(idx, None)
        if dst_poly is None:
            geo_def = self.target_geo_def[dst_y_start:dst_y_end,
                                          dst_x_start:dst_x_end]
            dst_poly = self._get_prj_poly(geo_def)
            self.dst_polys[idx] = dst_poly
        return dst_poly

    def get_chunk_mappings(self):
        """Map source and target chunks together if they overlap."""
        src_y_chunks, src_x_chunks = self.src_x.chunks
        dst_y_chunks, dst_x_chunks = self.dst_x.chunks

        coverage_status = []
        src_slices, dst_slices = [], []
        dst_mosaic_locations = []

        # Cartesian product of all source chunks with all target chunks;
        # each pairing records its slices and whether the polygons overlap.
        src_x_start = 0
        for src_x_step in src_x_chunks:
            src_x_end = src_x_start + src_x_step
            src_y_start = 0
            for src_y_step in src_y_chunks:
                src_y_end = src_y_start + src_y_step
                # Get source chunk polygon
                src_poly = self._get_src_poly(src_y_start, src_y_end,
                                              src_x_start, src_x_end)

                dst_x_start = 0
                for x_step_number, dst_x_step in enumerate(dst_x_chunks):
                    dst_x_end = dst_x_start + dst_x_step
                    dst_y_start = 0
                    for y_step_number, dst_y_step in enumerate(dst_y_chunks):
                        dst_y_end = dst_y_start + dst_y_step
                        # Get destination chunk polygon
                        dst_poly = self._get_dst_poly((x_step_number, y_step_number),
                                                      dst_x_start, dst_x_end,
                                                      dst_y_start, dst_y_end)

                        covers = check_overlap(src_poly, dst_poly)

                        coverage_status.append(covers)
                        src_slices.append((src_y_start, src_y_end,
                                           src_x_start, src_x_end))
                        dst_slices.append((dst_y_start, dst_y_end,
                                           dst_x_start, dst_x_end))
                        dst_mosaic_locations.append((x_step_number, y_step_number))

                        dst_y_start = dst_y_end
                    dst_x_start = dst_x_end
                src_y_start = src_y_end
            src_x_start = src_x_end

        self.src_slices = src_slices
        self.dst_slices = dst_slices
        self.dst_mosaic_locations = dst_mosaic_locations
        self.coverage_status = coverage_status

    def _filter_data(self, data, is_src=True, add_dim=False):
        """Filter unused chunks from the given array."""
        if add_dim:
            if data.ndim not in [2, 3]:
                raise NotImplementedError('Gradient search resampling only '
                                          'supports 2D or 3D arrays.')
            if data.ndim == 2:
                # Promote to 3D so all data is handled as (bands, y, x).
                data = data[np.newaxis, :, :]

        # One entry per (src, dst) chunk pairing; None when they don't overlap.
        data_out = []
        for i, covers in enumerate(self.coverage_status):
            if covers:
                if is_src:
                    y_start, y_end, x_start, x_end = self.src_slices[i]
                else:
                    y_start, y_end, x_start, x_end = self.dst_slices[i]
                try:
                    val = data[:, y_start:y_end, x_start:x_end]
                except IndexError:
                    # 2D coordinate arrays have no leading band dimension.
                    val = data[y_start:y_end, x_start:x_end]
            else:
                val = None
            data_out.append(val)

        return data_out

    def _get_gradients(self):
        """Get gradients in X and Y directions."""
        self.src_gradient_xl, self.src_gradient_xp = np.gradient(
            self.src_x, axis=[0, 1])
        self.src_gradient_yl, self.src_gradient_yp = np.gradient(
            self.src_y, axis=[0, 1])

    def _filter_src_dst(self):
        """Filter source and target chunks."""
        self.src_x = self._filter_data(self.src_x)
        self.src_y = self._filter_data(self.src_y)
        self.src_gradient_yl = self._filter_data(self.src_gradient_yl)
        self.src_gradient_yp = self._filter_data(self.src_gradient_yp)
        self.src_gradient_xl = self._filter_data(self.src_gradient_xl)
        self.src_gradient_xp = self._filter_data(self.src_gradient_xp)
        self.dst_x = self._filter_data(self.dst_x, is_src=False)
        self.dst_y = self._filter_data(self.dst_y, is_src=False)
        self._src_dst_filtered = True

    def compute(self, data, fill_value=None, **kwargs):
        """Resample the given data using gradient search algorithm."""
        # Chunking is taken from a single band so data and coordinates align.
        if 'bands' in data.dims:
            datachunks = data.sel(bands=data.coords['bands'][0]).chunks
        else:
            datachunks = data.chunks
        data_dims = data.dims
        data_coords = data.coords

        self._get_projection_coordinates(datachunks)

        # Each preparation step runs once; repeated compute() calls reuse state.
        if self.src_gradient_xl is None:
            self._get_gradients()
        if self.coverage_status is None:
            self.get_chunk_mappings()
        if not self._src_dst_filtered:
            self._filter_src_dst()

        data = self._filter_data(data.data, add_dim=True)

        res = parallel_gradient_search(data,
                                       self.src_x, self.src_y,
                                       self.dst_x, self.dst_y,
                                       self.src_gradient_xl,
                                       self.src_gradient_xp,
                                       self.src_gradient_yl,
                                       self.src_gradient_yp,
                                       self.dst_mosaic_locations,
                                       self.dst_slices,
                                       **kwargs)

        coords = _fill_in_coords(self.target_geo_def, data_coords, data_dims)

        if fill_value is not None:
            # Gradient search marks missing output with NaN; replace on request.
            res = da.where(np.isnan(res), fill_value, res)
        if res.ndim > len(data_dims):
            # Drop the band dimension added for originally-2D input.
            res = res.squeeze()

        res = xr.DataArray(res, dims=data_dims, coords=coords)
        return res
def is_swath_to_area(source_geo_def, target_geo_def):
    """Check if source is swath and target is area."""
    source_is_swath = isinstance(source_geo_def, SwathDefinition)
    target_is_area = isinstance(target_geo_def, AreaDefinition)
    return source_is_swath and target_is_area


def check_overlap(src_poly, dst_poly):
    """Check if the two polygons overlap.

    A polygon value of False stands for a SwathDefinition (assume overlap),
    while None marks an out-of-earth-disk area (no overlap possible).
    """
    # swath definition case: always assume coverage
    if src_poly is False or dst_poly is False:
        return True
    # out of earth disk case: no coverage
    if src_poly is None or dst_poly is None:
        return False
    # area / area case: ask the geometries themselves
    return src_poly.intersects(dst_poly)
def is_area_to_swath(source_geo_def, target_geo_def):
    """Check if source is area and target is swath."""
    return isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, SwathDefinition)


def _gradient_resample_data(src_data, src_x, src_y,
Expand Down Expand Up @@ -367,30 +137,6 @@ def _check_input_coordinates(dst_x, dst_y,
raise ValueError("Target arrays should all have the same shape")


def get_border_lonlats(geo_def: AreaDefinition):
    """Get the border x- and y-coordinates.

    Returns a (lon, lat) pair of 1D arrays following the area boundary.
    Geostationary areas use the specialized bounding-box helper; all other
    areas concatenate the four boundary sides.
    """
    if geo_def.is_geostationary:
        return get_geostationary_bounding_box_in_lonlats(geo_def, 3600)
    lons, lats = geo_def.get_boundary_lonlats()
    lon_b = np.concatenate([lons.side1, lons.side2, lons.side3, lons.side4])
    lat_b = np.concatenate([lats.side1, lats.side2, lats.side3, lats.side4])
    return lon_b, lat_b


def get_polygon(prj, geo_def):
    """Get border polygon from area definition in projection *prj*.

    Non-finite projected points are dropped before building the polygon.
    Returns None when the resulting polygon is degenerate (zero or
    non-finite area).
    """
    lon_b, lat_b = get_border_lonlats(geo_def)
    x_borders, y_borders = prj(lon_b, lat_b)
    boundary = [(x, y) for x, y in zip(x_borders, y_borders)
                if np.isfinite(x) and np.isfinite(y)]
    poly = Polygon(boundary)
    if np.isfinite(poly.area) and poly.area > 0.0:
        return poly
    return None


def parallel_gradient_search(data, src_x, src_y, dst_x, dst_y,
src_gradient_xl, src_gradient_xp,
src_gradient_yl, src_gradient_yp,
Expand Down Expand Up @@ -456,7 +202,10 @@ def _concatenate_chunks(chunks):


def _fill_in_coords(target_geo_def, data_coords, data_dims):
x_coord, y_coord = target_geo_def.get_proj_vectors()
try:
x_coord, y_coord = target_geo_def.get_proj_vectors()
except AttributeError:
return None
coords = []
for key in data_dims:
if key == 'x':
Expand Down Expand Up @@ -489,10 +238,10 @@ class ResampleBlocksGradientSearchResampler(BaseResampler):

def __init__(self, source_geo_def, target_geo_def):
    """Init GradientResampler.

    Swath sources have their lon/lat dask arrays persisted up front so the
    coordinates are computed only once across all resampled blocks.
    """
    # NOTE: the original diff text still carried the removed pre-PR guard
    # raising NotImplementedError for SwathDefinition targets; the merged
    # version supports area->swath and drops that guard.
    if isinstance(source_geo_def, SwathDefinition):
        source_geo_def.lons = source_geo_def.lons.persist()
        source_geo_def.lats = source_geo_def.lats.persist()
    super().__init__(source_geo_def, target_geo_def)
    logger.debug("/!\\ Instantiating an experimental GradientSearch resampler /!\\")
    self.indices_xy = None

def precompute(self, **kwargs):
Expand Down Expand Up @@ -590,14 +339,21 @@ def gradient_resampler_indices(source_area, target_area, block_info=None, **kwar
def _get_coordinates_in_same_projection(source_area, target_area):
    """Express source and target coordinates in one common working CRS.

    The working CRS is the source area's CRS when the source is an
    AreaDefinition, otherwise (swath source) the target area's CRS.

    Returns:
        ((dst_x, dst_y),
         (src_gradient_xl, src_gradient_xp, src_gradient_yl, src_gradient_yp),
         (src_x, src_y))
    """
    # NOTE: the original diff text interleaved the removed pre-PR lines
    # (a conflicting transformer and NotImplementedError raises) with the
    # added swath handling; this is the merged post-PR version.
    try:
        src_x, src_y = source_area.get_proj_coords()
        work_crs = source_area.crs
    except AttributeError:
        # source is a swath definition, use target crs instead
        lons, lats = source_area.get_lonlats()
        src_x, src_y = da.compute(lons, lats)
        trans = pyproj.Transformer.from_crs(source_area.crs, target_area.crs, always_xy=True)
        src_x, src_y = trans.transform(src_x, src_y)
        work_crs = target_area.crs
    transformer = pyproj.Transformer.from_crs(target_area.crs, work_crs, always_xy=True)
    try:
        dst_x, dst_y = transformer.transform(*target_area.get_proj_coords())
    except AttributeError:
        # target is a swath definition
        lons, lats = target_area.get_lonlats()
        dst_x, dst_y = transformer.transform(*da.compute(lons, lats))
    src_gradient_xl, src_gradient_xp = np.gradient(src_x, axis=[0, 1])
    src_gradient_yl, src_gradient_yp = np.gradient(src_y, axis=[0, 1])
    return (dst_x, dst_y), (src_gradient_xl, src_gradient_xp, src_gradient_yl, src_gradient_yp), (src_x, src_y)
Expand All @@ -610,6 +366,9 @@ def block_bilinear_interpolator(data, indices_xy, fill_value=np.nan, block_info=
weight_l, l_start = np.modf(y_indices.clip(0, data.shape[-2] - 1))
weight_p, p_start = np.modf(x_indices.clip(0, data.shape[-1] - 1))

weight_l = weight_l.astype(data.dtype)
weight_p = weight_p.astype(data.dtype)

l_start = l_start.astype(int)
p_start = p_start.astype(int)
l_end = np.clip(l_start + 1, 1, data.shape[-2] - 1)
Expand Down
8 changes: 4 additions & 4 deletions pyresample/gradient/_gradient_search.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ cdef inline void bil(const data_type[:, :, :] data, int l0, int p0, float_index
p_b = min(p0 + 1, pmax)
w_p = dp
for i in range(z_size):
res[i] = ((1 - w_l) * (1 - w_p) * data[i, l_a, p_a] +
(1 - w_l) * w_p * data[i, l_a, p_b] +
w_l * (1 - w_p) * data[i, l_b, p_a] +
w_l * w_p * data[i, l_b, p_b])
res[i] = <data_type>((1 - w_l) * (1 - w_p) * data[i, l_a, p_a] +
(1 - w_l) * w_p * data[i, l_a, p_b] +
w_l * (1 - w_p) * data[i, l_b, p_a] +
w_l * w_p * data[i, l_b, p_b])


@cython.boundscheck(False)
Expand Down
7 changes: 3 additions & 4 deletions pyresample/resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ def resample_blocks(func, src_area, src_arrays, dst_area,
fill_value: Desired value for any invalid values in the output array
kwargs: any other keyword arguments that will be passed on to func.

Returns:
A dask array, chunked as dst_area, containing the resampled data.


Principle of operations:
Resample_blocks works by iterating over chunks on the dst_area domain. For each chunk, the corresponding slice
Expand All @@ -235,10 +238,6 @@ def resample_blocks(func, src_area, src_arrays, dst_area,


"""
if dst_area == src_area:
raise ValueError("Source and destination areas are identical."
" Should you be running `map_blocks` instead of `resample_blocks`?")

name = _create_dask_name(name, func,
src_area, src_arrays,
dst_area, dst_arrays,
Expand Down
Loading
Loading