From 3c072554474a7d0087e98b8c3c3396557d631ea9 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 22 Oct 2024 11:04:10 +0200 Subject: [PATCH 1/6] Replace stacking gradient search with resample_blocks variant --- pyresample/gradient/__init__.py | 243 +---------------- pyresample/gradient/_gradient_search.pyx | 8 +- pyresample/resampler.py | 7 +- pyresample/test/test_gradient.py | 332 +++++------------------ 4 files changed, 81 insertions(+), 509 deletions(-) diff --git a/pyresample/gradient/__init__.py b/pyresample/gradient/__init__.py index 41c98846..ddc3e519 100644 --- a/pyresample/gradient/__init__.py +++ b/pyresample/gradient/__init__.py @@ -53,244 +53,19 @@ def GradientSearchResampler(source_geo_def, target_geo_def): def create_gradient_search_resampler(source_geo_def, target_geo_def): """Create a gradient search resampler.""" - if isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, AreaDefinition): + if ((isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, AreaDefinition)) or + (isinstance(source_geo_def, SwathDefinition) and isinstance(target_geo_def, AreaDefinition))): return ResampleBlocksGradientSearchResampler(source_geo_def, target_geo_def) - elif isinstance(source_geo_def, SwathDefinition) and isinstance(target_geo_def, AreaDefinition): - return StackingGradientSearchResampler(source_geo_def, target_geo_def) raise NotImplementedError @da.as_gufunc(signature='(),()->(),()') def transform(x_coords, y_coords, src_prj=None, dst_prj=None): """Calculate projection coordinates.""" - transformer = pyproj.Transformer.from_crs(src_prj, dst_prj) + transformer = pyproj.Transformer.from_crs(src_prj, dst_prj, always_xy=True) return transformer.transform(x_coords, y_coords) -class StackingGradientSearchResampler(BaseResampler): - """Resample using gradient search based bilinear interpolation, using stacking for dask processing.""" - - def __init__(self, source_geo_def, target_geo_def): - """Init GradientResampler.""" - super().__init__(source_geo_def, target_geo_def) - import warnings - warnings.warn("You are using the Gradient Search Resampler, which is still EXPERIMENTAL.", stacklevel=2) - self.use_input_coords = None - self._src_dst_filtered = False - self.prj = None - self.src_x = None - self.src_y = None - self.src_slices = None - self.dst_x = None - self.dst_y = None - self.dst_slices = None - self.src_gradient_xl = None - self.src_gradient_xp = None - self.src_gradient_yl = None - self.src_gradient_yp = None - self.dst_polys = {} - self.dst_mosaic_locations = None - self.coverage_status = None - - def _get_projection_coordinates(self, datachunks): - """Get projection coordinates.""" - if self.use_input_coords is None: - try: - self.src_x, self.src_y = self.source_geo_def.get_proj_coords( - chunks=datachunks) - src_crs = self.source_geo_def.crs - self.use_input_coords = True - except AttributeError: - self.src_x, self.src_y = self.source_geo_def.get_lonlats( - chunks=datachunks) - src_crs = pyproj.CRS.from_string("+proj=longlat") - self.use_input_coords = False - try: - self.dst_x, self.dst_y = self.target_geo_def.get_proj_coords( - chunks=CHUNK_SIZE) - dst_crs = self.target_geo_def.crs - except AttributeError as err: - if self.use_input_coords is False: - raise NotImplementedError('Cannot resample lon/lat to lon/lat with gradient search.') from err - self.dst_x, self.dst_y = self.target_geo_def.get_lonlats( - chunks=CHUNK_SIZE) - dst_crs = pyproj.CRS.from_string("+proj=longlat") - if self.use_input_coords: - self.dst_x, self.dst_y = transform( - self.dst_x, self.dst_y, - src_prj=dst_crs, dst_prj=src_crs) - self.prj = pyproj.Proj(self.source_geo_def.crs) - else: - self.src_x, self.src_y = transform( - self.src_x, self.src_y, - src_prj=src_crs, dst_prj=dst_crs) - self.prj = pyproj.Proj(self.target_geo_def.crs) - - def _get_prj_poly(self, geo_def): - # - None if out of Earth Disk - # - False is SwathDefinition - if isinstance(geo_def, SwathDefinition): - return False - try: - poly = get_polygon(self.prj, geo_def) - except (NotImplementedError, ValueError): # out-of-earth disk area or any valid projected boundary coordinates - poly = None - return poly - - def _get_src_poly(self, src_y_start, src_y_end, src_x_start, src_x_end): - """Get bounding polygon for source chunk.""" - geo_def = self.source_geo_def[src_y_start:src_y_end, - src_x_start:src_x_end] - return self._get_prj_poly(geo_def) - - def _get_dst_poly(self, idx, - dst_x_start, dst_x_end, - dst_y_start, dst_y_end): - """Get target chunk polygon.""" - dst_poly = self.dst_polys.get(idx, None) - if dst_poly is None: - geo_def = self.target_geo_def[dst_y_start:dst_y_end, - dst_x_start:dst_x_end] - dst_poly = self._get_prj_poly(geo_def) - self.dst_polys[idx] = dst_poly - return dst_poly - - def get_chunk_mappings(self): - """Map source and target chunks together if they overlap.""" - src_y_chunks, src_x_chunks = self.src_x.chunks - dst_y_chunks, dst_x_chunks = self.dst_x.chunks - - coverage_status = [] - src_slices, dst_slices = [], [] - dst_mosaic_locations = [] - - src_x_start = 0 - for src_x_step in src_x_chunks: - src_x_end = src_x_start + src_x_step - src_y_start = 0 - for src_y_step in src_y_chunks: - src_y_end = src_y_start + src_y_step - # Get source chunk polygon - src_poly = self._get_src_poly(src_y_start, src_y_end, - src_x_start, src_x_end) - - dst_x_start = 0 - for x_step_number, dst_x_step in enumerate(dst_x_chunks): - dst_x_end = dst_x_start + dst_x_step - dst_y_start = 0 - for y_step_number, dst_y_step in enumerate(dst_y_chunks): - dst_y_end = dst_y_start + dst_y_step - # Get destination chunk polygon - dst_poly = self._get_dst_poly((x_step_number, y_step_number), - dst_x_start, dst_x_end, - dst_y_start, dst_y_end) - - covers = check_overlap(src_poly, dst_poly) - - coverage_status.append(covers) - src_slices.append((src_y_start, src_y_end, - src_x_start, src_x_end)) - dst_slices.append((dst_y_start, dst_y_end, - dst_x_start, dst_x_end)) - dst_mosaic_locations.append((x_step_number, y_step_number)) - - dst_y_start = dst_y_end - dst_x_start = dst_x_end - src_y_start = src_y_end - src_x_start = src_x_end - - self.src_slices = src_slices - self.dst_slices = dst_slices - self.dst_mosaic_locations = dst_mosaic_locations - self.coverage_status = coverage_status - - def _filter_data(self, data, is_src=True, add_dim=False): - """Filter unused chunks from the given array.""" - if add_dim: - if data.ndim not in [2, 3]: - raise NotImplementedError('Gradient search resampling only ' - 'supports 2D or 3D arrays.') - if data.ndim == 2: - data = data[np.newaxis, :, :] - - data_out = [] - for i, covers in enumerate(self.coverage_status): - if covers: - if is_src: - y_start, y_end, x_start, x_end = self.src_slices[i] - else: - y_start, y_end, x_start, x_end = self.dst_slices[i] - try: - val = data[:, y_start:y_end, x_start:x_end] - except IndexError: - val = data[y_start:y_end, x_start:x_end] - else: - val = None - data_out.append(val) - - return data_out - - def _get_gradients(self): - """Get gradients in X and Y directions.""" - self.src_gradient_xl, self.src_gradient_xp = np.gradient( - self.src_x, axis=[0, 1]) - self.src_gradient_yl, self.src_gradient_yp = np.gradient( - self.src_y, axis=[0, 1]) - - def _filter_src_dst(self): - """Filter source and target chunks.""" - self.src_x = self._filter_data(self.src_x) - self.src_y = self._filter_data(self.src_y) - self.src_gradient_yl = self._filter_data(self.src_gradient_yl) - self.src_gradient_yp = self._filter_data(self.src_gradient_yp) - self.src_gradient_xl = self._filter_data(self.src_gradient_xl) - self.src_gradient_xp = self._filter_data(self.src_gradient_xp) - self.dst_x = self._filter_data(self.dst_x, is_src=False) - self.dst_y = self._filter_data(self.dst_y, is_src=False) - self._src_dst_filtered = True - - def compute(self, data, fill_value=None, **kwargs): - """Resample the given data using gradient search algorithm.""" - if 'bands' in data.dims: - datachunks = data.sel(bands=data.coords['bands'][0]).chunks - else: - datachunks = data.chunks - data_dims = data.dims - data_coords = data.coords - - self._get_projection_coordinates(datachunks) - - if self.src_gradient_xl is None: - self._get_gradients() - if self.coverage_status is None: - self.get_chunk_mappings() - if not self._src_dst_filtered: - self._filter_src_dst() - - data = self._filter_data(data.data, add_dim=True) - - res = parallel_gradient_search(data, - self.src_x, self.src_y, - self.dst_x, self.dst_y, - self.src_gradient_xl, - self.src_gradient_xp, - self.src_gradient_yl, - self.src_gradient_yp, - self.dst_mosaic_locations, - self.dst_slices, - **kwargs) - - coords = _fill_in_coords(self.target_geo_def, data_coords, data_dims) - - if fill_value is not None: - res = da.where(np.isnan(res), fill_value, res) - if res.ndim > len(data_dims): - res = res.squeeze() - - res = xr.DataArray(res, dims=data_dims, coords=coords) - return res - - def check_overlap(src_poly, dst_poly): """Check if the two polygons overlap.""" # swath definition case @@ -491,8 +266,10 @@ def __init__(self, source_geo_def, target_geo_def): """Init GradientResampler.""" if isinstance(target_geo_def, SwathDefinition): raise NotImplementedError("Cannot resample to a SwathDefinition.") + if isinstance(source_geo_def, SwathDefinition): + source_geo_def.lons = source_geo_def.lons.persist() + source_geo_def.lats = source_geo_def.lats.persist() super().__init__(source_geo_def, target_geo_def) - logger.debug("/!\\ Instantiating an experimental GradientSearch resampler /!\\") self.indices_xy = None def precompute(self, **kwargs): @@ -590,11 +367,11 @@ def gradient_resampler_indices(source_area, target_area, block_info=None, **kwar def _get_coordinates_in_same_projection(source_area, target_area): try: src_x, src_y = source_area.get_proj_coords() - transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True) except AttributeError as err: - raise NotImplementedError("Cannot resample from Swath for now.") from err - + lons, lats = source_area.get_lonlats() + src_x, src_y = da.compute(lons, lats) try: + transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True) dst_x, dst_y = transformer.transform(*target_area.get_proj_coords()) except AttributeError as err: raise NotImplementedError("Cannot resample to Swath for now.") from err @@ -618,7 +395,7 @@ def block_bilinear_interpolator(data, indices_xy, fill_value=np.nan, block_info= res = ((1 - weight_l) * (1 - weight_p) * data[..., l_start, p_start] + (1 - weight_l) * weight_p * data[..., l_start, p_end] + weight_l * (1 - weight_p) * data[..., l_end, p_start] + - weight_l * weight_p * data[..., l_end, p_end]) + weight_l * weight_p * data[..., l_end, p_end]).astype(data.dtype) res = np.where(mask, fill_value, res) return res diff --git a/pyresample/gradient/_gradient_search.pyx b/pyresample/gradient/_gradient_search.pyx index e291866e..d0127cac 100644 --- a/pyresample/gradient/_gradient_search.pyx +++ b/pyresample/gradient/_gradient_search.pyx @@ -80,10 +80,10 @@ cdef inline void bil(const data_type[:, :, :] data, int l0, int p0, float_index p_b = min(p0 + 1, pmax) w_p = dp for i in range(z_size): - res[i] = ((1 - w_l) * (1 - w_p) * data[i, l_a, p_a] + - (1 - w_l) * w_p * data[i, l_a, p_b] + - w_l * (1 - w_p) * data[i, l_b, p_a] + - w_l * w_p * data[i, l_b, p_b]) + res[i] = ((1 - w_l) * (1 - w_p) * data[i, l_a, p_a] + + (1 - w_l) * w_p * data[i, l_a, p_b] + + w_l * (1 - w_p) * data[i, l_b, p_a] + + w_l * w_p * data[i, l_b, p_b]) @cython.boundscheck(False) diff --git a/pyresample/resampler.py b/pyresample/resampler.py index 6a5ff3e8..5883852f 100644 --- a/pyresample/resampler.py +++ b/pyresample/resampler.py @@ -214,6 +214,9 @@ def resample_blocks(func, src_area, src_arrays, dst_area, fill_value: Desired value for any invalid values in the output array kwargs: any other keyword arguments that will be passed on to func. + Returns: + A dask array, chunked as dst_area, containing the resampled data. + Principle of operations: Resample_blocks works by iterating over chunks on the dst_area domain. For each chunk, the corresponding slice @@ -235,10 +238,6 @@ def resample_blocks(func, src_area, src_arrays, dst_area, """ - if dst_area == src_area: - raise ValueError("Source and destination areas are identical." - " Should you be running `map_blocks` instead of `resample_blocks`?") - name = _create_dask_name(name, func, src_area, src_arrays, dst_area, dst_arrays, diff --git a/pyresample/test/test_gradient.py b/pyresample/test/test_gradient.py index 5b915be2..90983847 100644 --- a/pyresample/test/test_gradient.py +++ b/pyresample/test/test_gradient.py @@ -35,259 +35,6 @@ from pyresample.gradient import ResampleBlocksGradientSearchResampler -class TestOGradientResampler: - """Test case for the gradient resampling.""" - - def setup_method(self): - """Set up the test case.""" - from pyresample.gradient import StackingGradientSearchResampler - self.src_area = AreaDefinition('dst', 'dst area', None, - {'ellps': 'WGS84', 'h': '35785831', 'proj': 'geos'}, - 100, 100, - (5550000.0, 5550000.0, -5550000.0, -5550000.0)) - self.src_swath = SwathDefinition(*self.src_area.get_lonlats()) - self.dst_area = AreaDefinition('euro40', 'euro40', None, - {'proj': 'stere', 'lon_0': 14.0, - 'lat_0': 90.0, 'lat_ts': 60.0, - 'ellps': 'bessel'}, - 102, 102, - (-2717181.7304994687, -5571048.14031214, - 1378818.2695005313, -1475048.1403121399)) - self.dst_swath = SwathDefinition(*self.dst_area.get_lonlats()) - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message=".*which is still EXPERIMENTAL.*", category=UserWarning) - self.resampler = StackingGradientSearchResampler(self.src_area, self.dst_area) - self.swath_resampler = StackingGradientSearchResampler(self.src_swath, self.dst_area) - self.area_to_swath_resampler = StackingGradientSearchResampler(self.src_area, self.dst_swath) - - def test_get_projection_coordinates_area_to_area(self): - """Check that the coordinates are initialized, for area -> area.""" - assert self.resampler.prj is None - self.resampler._get_projection_coordinates((10, 10)) - cdst_x = self.resampler.dst_x.compute() - cdst_y = self.resampler.dst_y.compute() - assert np.allclose(np.min(cdst_x), -2022632.1675016289) - assert np.allclose(np.max(cdst_x), 2196052.591296284) - assert np.allclose(np.min(cdst_y), 3517933.413092212) - assert np.allclose(np.max(cdst_y), 5387038.893400168) - assert self.resampler.use_input_coords - assert self.resampler.prj is not None - - def test_get_projection_coordinates_swath_to_area(self): - """Check that the coordinates are initialized, for swath -> area.""" - assert self.swath_resampler.prj is None - self.swath_resampler._get_projection_coordinates((10, 10)) - cdst_x = self.swath_resampler.dst_x.compute() - cdst_y = self.swath_resampler.dst_y.compute() - assert np.allclose(np.min(cdst_x), -2697103.29912692) - assert np.allclose(np.max(cdst_x), 1358739.8381279823) - assert np.allclose(np.min(cdst_y), -5550969.708939591) - assert np.allclose(np.max(cdst_y), -1495126.5716846888) - assert self.swath_resampler.use_input_coords is False - assert self.swath_resampler.prj is not None - - def test_get_gradients(self): - """Test that coordinate gradients are computed correctly.""" - self.resampler._get_projection_coordinates((10, 10)) - assert self.resampler.src_gradient_xl is None - self.resampler._get_gradients() - assert self.resampler.src_gradient_xl.compute().max() == 0.0 - assert self.resampler.src_gradient_xp.compute().max() == -111000.0 - assert self.resampler.src_gradient_yl.compute().max() == 111000.0 - assert self.resampler.src_gradient_yp.compute().max() == 0.0 - - def test_get_chunk_mappings(self): - """Test that chunk overlap, and source and target slices are correct.""" - chunks = (10, 10) - num_chunks = np.prod(chunks) - self.resampler._get_projection_coordinates(chunks) - self.resampler._get_gradients() - assert self.resampler.coverage_status is None - self.resampler.get_chunk_mappings() - # 8 source chunks overlap the target area - covered_src_chunks = np.array([38, 39, 48, 49, 58, 59, 68, 69]) - res = np.where(self.resampler.coverage_status)[0] - assert np.all(res == covered_src_chunks) - # All *num_chunks* should have values in the lists - assert len(self.resampler.coverage_status) == num_chunks - assert len(self.resampler.src_slices) == num_chunks - assert len(self.resampler.dst_slices) == num_chunks - assert len(self.resampler.dst_mosaic_locations) == num_chunks - # There's only one output chunk, and the covered source chunks - # should have destination locations of (0, 0) - res = np.array(self.resampler.dst_mosaic_locations)[covered_src_chunks] - assert all([all(loc == (0, 0)) for loc in list(res)]) - - def test_get_src_poly_area(self): - """Test defining source chunk polygon for AreaDefinition.""" - chunks = (10, 10) - self.resampler._get_projection_coordinates(chunks) - self.resampler._get_gradients() - poly = self.resampler._get_src_poly(0, 40, 0, 40) - assert np.allclose(poly.area, 12365358458842.43) - - def test_get_src_poly_swath(self): - """Test defining source chunk polygon for SwathDefinition.""" - chunks = (10, 10) - self.swath_resampler._get_projection_coordinates(chunks) - self.swath_resampler._get_gradients() - # SwathDefinition can't be sliced, so False is returned - poly = self.swath_resampler._get_src_poly(0, 40, 0, 40) - assert poly is False - - @mock.patch('pyresample.gradient.get_polygon') - def test_get_dst_poly_area(self, get_polygon): - """Test defining destination chunk polygon.""" - chunks = (10, 10) - self.resampler._get_projection_coordinates(chunks) - self.resampler._get_gradients() - # First call should make a call to get_polygon() - self.resampler._get_dst_poly('idx1', 0, 10, 0, 10) - assert get_polygon.call_count == 1 - assert 'idx1' in self.resampler.dst_polys - # The second call to the same index should come from cache - self.resampler._get_dst_poly('idx1', 0, 10, 0, 10) - assert get_polygon.call_count == 1 - - def test_get_dst_poly_swath(self): - """Test defining dst chunk polygon for SwathDefinition.""" - chunks = (10, 10) - self.area_to_swath_resampler._get_projection_coordinates(chunks) - self.area_to_swath_resampler._get_gradients() - # SwathDefinition can't be sliced, so False is returned - self.area_to_swath_resampler._get_dst_poly('idx2', 0, 10, 0, 10) - assert self.area_to_swath_resampler.dst_polys['idx2'] is False - - def test_filter_data(self): - """Test filtering chunks that do not overlap.""" - chunks = (10, 10) - self.resampler._get_projection_coordinates(chunks) - self.resampler._get_gradients() - self.resampler.get_chunk_mappings() - - # Basic filtering. There should be 8 dask arrays that each - # have a shape of (10, 10) - res = self.resampler._filter_data(self.resampler.src_x) - valid = [itm for itm in res if itm is not None] - assert len(valid) == 8 - shapes = [arr.shape for arr in valid] - for shp in shapes: - assert shp == (10, 10) - - # Destination x/y coordinate array filtering. Again, 8 dask - # arrays each with shape (102, 102) - res = self.resampler._filter_data(self.resampler.dst_x, is_src=False) - valid = [itm for itm in res if itm is not None] - assert len(valid) == 8 - shapes = [arr.shape for arr in valid] - for shp in shapes: - assert shp == (102, 102) - - # Add a dimension to the given dataset - data = da.random.random(self.src_area.shape) - res = self.resampler._filter_data(data, add_dim=True) - valid = [itm for itm in res if itm is not None] - assert len(valid) == 8 - shapes = [arr.shape for arr in valid] - for shp in shapes: - assert shp == (1, 10, 10) - - # 1D and 3+D should raise NotImplementedError - data = da.random.random((3,)) - try: - res = self.resampler._filter_data(data, add_dim=True) - raise IndexError - except NotImplementedError: - pass - data = da.random.random((3, 3, 3, 3)) - try: - res = self.resampler._filter_data(data, add_dim=True) - raise IndexError - except NotImplementedError: - pass - - def test_resample_area_to_area_2d(self): - """Resample area to area, 2d.""" - data = xr.DataArray(da.ones(self.src_area.shape, dtype=np.float64), - dims=['y', 'x']) - res = self.resampler.compute( - data, method='bil').compute(scheduler='single-threaded') - assert res.shape == self.dst_area.shape - assert np.allclose(res, 1) - - def test_resample_area_to_area_2d_fill_value(self): - """Resample area to area, 2d, use fill value.""" - data = xr.DataArray(da.full(self.src_area.shape, np.nan, - dtype=np.float64), dims=['y', 'x']) - res = self.resampler.compute( - data, method='bil', - fill_value=2.0).compute(scheduler='single-threaded') - assert res.shape == self.dst_area.shape - assert np.allclose(res, 2.0) - - def test_resample_area_to_area_3d(self): - """Resample area to area, 3d.""" - data = xr.DataArray(da.ones((3, ) + self.src_area.shape, - dtype=np.float64) * - np.array([1, 2, 3])[:, np.newaxis, np.newaxis], - dims=['bands', 'y', 'x']) - res = self.resampler.compute( - data, method='bil').compute(scheduler='single-threaded') - assert res.shape == (3, ) + self.dst_area.shape - assert np.allclose(res[0, :, :], 1.0) - assert np.allclose(res[1, :, :], 2.0) - assert np.allclose(res[2, :, :], 3.0) - - def test_resample_area_to_area_3d_single_channel(self): - """Resample area to area, 3d with only a single band.""" - data = xr.DataArray(da.ones((1, ) + self.src_area.shape, - dtype=np.float64), - dims=['bands', 'y', 'x']) - res = self.resampler.compute( - data, method='bil').compute(scheduler='single-threaded') - assert res.shape == (1, ) + self.dst_area.shape - assert np.allclose(res[0, :, :], 1.0) - - @pytest.mark.parametrize("input_dtype", (np.float32, np.float64)) - def test_resample_swath_to_area_2d(self, input_dtype): - """Resample swath to area, 2d.""" - data = xr.DataArray(da.ones(self.src_swath.shape, dtype=input_dtype), - dims=['y', 'x']) - with np.errstate(invalid="ignore"): # 'inf' space pixels cause runtime warnings - res_xr = self.swath_resampler.compute(data, method='bil') - res_np = res_xr.compute(scheduler='single-threaded') - - assert res_xr.dtype == data.dtype - assert res_np.dtype == data.dtype - assert res_xr.shape == self.dst_area.shape - assert res_np.shape == self.dst_area.shape - assert type(res_xr) is type(data) - assert type(res_xr.data) is type(data.data) - assert not np.all(np.isnan(res_np)) - - @pytest.mark.parametrize("input_dtype", (np.float32, np.float64)) - def test_resample_swath_to_area_3d(self, input_dtype): - """Resample area to area, 3d.""" - data = xr.DataArray(da.ones((3, ) + self.src_swath.shape, - dtype=input_dtype) * - np.array([1, 2, 3])[:, np.newaxis, np.newaxis], - dims=['bands', 'y', 'x']) - with np.errstate(invalid="ignore"): # 'inf' space pixels cause runtime warnings - res_xr = self.swath_resampler.compute(data, method='bil') - res_np = res_xr.compute(scheduler='single-threaded') - - assert res_xr.dtype == data.dtype - assert res_np.dtype == data.dtype - assert res_xr.shape == (3, ) + self.dst_area.shape - assert res_np.shape == (3, ) + self.dst_area.shape - assert type(res_xr) is type(data) - assert type(res_xr.data) is type(data.data) - for i in range(res_np.shape[0]): - arr = np.ravel(res_np[i, :, :]) - assert np.allclose(arr[np.isfinite(arr)], float(i + 1)) - - class TestRBGradientSearchResamplerArea2Area: """Test RBGradientSearchResampler for the Area to Area case.""" @@ -509,6 +256,70 @@ def test_resample_area_to_area_nn(self): assert res.shape == dst_area.shape +class TestRBGradientSearchResamplerSwath2Area: + """Test RBGradientSearchResampler for the Area to Swath case.""" + + def setup_method(self): + """Set up the test case.""" + lons, lats = np.meshgrid(np.linspace(0, 20, 100), np.linspace(45, 66, 100)) + self.src_swath = SwathDefinition(lons, lats, crs="WGS84") + lons, lats = self.src_swath.get_lonlats(chunks=10) + lons = xr.DataArray(lons, dims=["y", "x"]) + lats = xr.DataArray(lats, dims=["y", "x"]) + self.src_swath_dask = SwathDefinition(lons, lats) + self.dst_area = AreaDefinition('euro40', 'euro40', None, + {'proj': 'stere', 'lon_0': 14.0, + 'lat_0': 90.0, 'lat_ts': 60.0, + 'ellps': 'bessel'}, + 102, 102, + (-2717181.7304994687, -5571048.14031214, + 1378818.2695005313, -1475048.1403121399)) + + @pytest.mark.parametrize("input_dtype", (np.float32, np.float64)) + def test_resample_swath_to_area_2d(self, input_dtype): + """Resample swath to area, 2d.""" + swath_resampler = ResampleBlocksGradientSearchResampler(self.src_swath_dask, self.dst_area) + + data = xr.DataArray(da.ones(self.src_swath.shape, dtype=input_dtype), + dims=['y', 'x']) + with np.errstate(invalid="ignore"): # 'inf' space pixels cause runtime warnings + swath_resampler.precompute() + res_xr = swath_resampler.compute(data, method='bilinear') + res_np = res_xr.compute(scheduler='single-threaded') + + assert res_xr.dtype == data.dtype + assert res_np.dtype == data.dtype + assert res_xr.shape == self.dst_area.shape + assert res_np.shape == self.dst_area.shape + assert type(res_xr) is type(data) + assert type(res_xr.data) is type(data.data) + assert not np.all(np.isnan(res_np)) + + @pytest.mark.parametrize("input_dtype", (np.float32, np.float64)) + def test_resample_swath_to_area_3d(self, input_dtype): + """Resample area to area, 3d.""" + swath_resampler = ResampleBlocksGradientSearchResampler(self.src_swath_dask, self.dst_area) + + data = xr.DataArray(da.ones((3, ) + self.src_swath.shape, + dtype=input_dtype) * + np.array([1, 2, 3])[:, np.newaxis, np.newaxis], + dims=['bands', 'y', 'x']) + with np.errstate(invalid="ignore"): # 'inf' space pixels cause runtime warnings + swath_resampler.precompute() + res_xr = swath_resampler.compute(data, method='bilinear') + res_np = res_xr.compute(scheduler='single-threaded') + + assert res_xr.dtype == data.dtype + assert res_np.dtype == data.dtype + assert res_xr.shape == (3, ) + self.dst_area.shape + assert res_np.shape == (3, ) + self.dst_area.shape + assert type(res_xr) is type(data) + assert type(res_xr.data) is type(data.data) + for i in range(res_np.shape[0]): + arr = np.ravel(res_np[i, :, :]) + assert np.allclose(arr[np.isfinite(arr)], float(i + 1)) + + class TestRBGradientSearchResamplerArea2Swath: """Test RBGradientSearchResampler for the Area to Swath case.""" @@ -827,21 +638,6 @@ def test_concatenate_chunks(): assert res.shape == (3, 8, 6) -@mock.patch('pyresample.gradient.da') -def test_concatenate_chunks_stack_calls(dask_da): - """Test that stacking is called the correct times in chunk concatenation.""" - from pyresample.gradient import _concatenate_chunks - - chunks = {(0, 0): [np.ones((1, 5, 4)), np.zeros((1, 5, 4))], - (1, 0): [np.zeros((1, 5, 2))], - (1, 1): [np.full((1, 3, 2), 0.5)], - (0, 1): [np.full((1, 3, 4), -1)]} - _ = _concatenate_chunks(chunks) - dask_da.stack.assert_called_once_with(chunks[(0, 0)], axis=-1) - dask_da.nanmax.assert_called_once() - assert 'axis=2' in str(dask_da.concatenate.mock_calls[-1]) - - class TestGradientCython(): """Test the core gradient features.""" From 2bef7293eff2b97a3bf0af903d79bcf41a3d6dbe Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 22 Oct 2024 11:24:31 +0200 Subject: [PATCH 2/6] Remove irrelevant test --- pyresample/test/test_resample_blocks.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/pyresample/test/test_resample_blocks.py b/pyresample/test/test_resample_blocks.py index 5d472f8f..1cea5711 100644 --- a/pyresample/test/test_resample_blocks.py +++ b/pyresample/test/test_resample_blocks.py @@ -58,18 +58,6 @@ def setup_method(self): (-2717181.7304994687, -5571048.14031214, 1378818.2695005313, -1475048.1403121399)) - def test_resample_blocks_advises_on_using_mapblocks_when_source_and_destination_areas_are_the_same(self): - """Test resample_blocks advises on using map_blocks when the source and destination areas are the same.""" - from pyresample.resampler import resample_blocks - - def fun(data): - return data - - some_array = da.random.random(self.src_area.shape) - with pytest.raises(ValueError) as excinfo: - resample_blocks(fun, self.src_area, [some_array], self.src_area) - assert "map_blocks" in str(excinfo.value) - def test_resample_blocks_returns_array_with_destination_area_shape(self): """Test resample_blocks returns array with the shape of the destination area.""" from pyresample.resampler import resample_blocks From cfb478dc0c756b4b10e2ad11aa9f6741f56ae3d5 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 22 Oct 2024 11:27:31 +0200 Subject: [PATCH 3/6] Clean up old code --- pyresample/gradient/__init__.py | 45 --------------------------------- 1 file changed, 45 deletions(-) diff --git a/pyresample/gradient/__init__.py b/pyresample/gradient/__init__.py index ddc3e519..c8750108 100644 --- a/pyresample/gradient/__init__.py +++ b/pyresample/gradient/__init__.py @@ -59,27 +59,6 @@ def create_gradient_search_resampler(source_geo_def, target_geo_def): raise NotImplementedError -@da.as_gufunc(signature='(),()->(),()') -def transform(x_coords, y_coords, src_prj=None, dst_prj=None): - """Calculate projection coordinates.""" - transformer = pyproj.Transformer.from_crs(src_prj, dst_prj, always_xy=True) - return transformer.transform(x_coords, y_coords) - - -def check_overlap(src_poly, dst_poly): - """Check if the two polygons overlap.""" - # swath definition case - if dst_poly is False or src_poly is False: - covers = True - # area / area case - elif dst_poly is not None and src_poly is not None: - covers = src_poly.intersects(dst_poly) - # out of earth disk case - else: - covers = False - return covers - - def _gradient_resample_data(src_data, src_x, src_y, src_gradient_xl, src_gradient_xp, src_gradient_yl, src_gradient_yp, @@ -142,30 +121,6 @@ def _check_input_coordinates(dst_x, dst_y, raise ValueError("Target arrays should all have the same shape") -def get_border_lonlats(geo_def: AreaDefinition): - """Get the border x- and y-coordinates.""" - if geo_def.is_geostationary: - lon_b, lat_b = get_geostationary_bounding_box_in_lonlats(geo_def, 3600) - else: - lons, lats = geo_def.get_boundary_lonlats() - lon_b = np.concatenate((lons.side1, lons.side2, lons.side3, lons.side4)) - lat_b = np.concatenate((lats.side1, lats.side2, lats.side3, lats.side4)) - - return lon_b, lat_b - - -def get_polygon(prj, geo_def): - """Get border polygon from area definition in projection *prj*.""" - lon_b, lat_b = get_border_lonlats(geo_def) - x_borders, y_borders = prj(lon_b, lat_b) - boundary = [(x_borders[i], y_borders[i]) for i in range(len(x_borders)) - if np.isfinite(x_borders[i]) and np.isfinite(y_borders[i])] - poly = Polygon(boundary) - if np.isfinite(poly.area) and poly.area > 0.0: - return poly - return None - - def parallel_gradient_search(data, src_x, src_y, dst_x, dst_y, src_gradient_xl, src_gradient_xp, src_gradient_yl, src_gradient_yp, From 7c971b10cc4f4aa91be6c9a6a84099c4894c0c0e Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 22 Oct 2024 11:30:46 +0200 Subject: [PATCH 4/6] Clean up unneeded tests --- pyresample/test/test_gradient.py | 97 -------------------------------- 1 file changed, 97 deletions(-) diff --git a/pyresample/test/test_gradient.py b/pyresample/test/test_gradient.py index 90983847..c911f297 100644 --- a/pyresample/test/test_gradient.py +++ b/pyresample/test/test_gradient.py @@ -407,103 +407,6 @@ def fake_compute(arg1, data): decorated('bla', data) -def test_check_overlap(): - """Test overlap check returning correct results.""" - from shapely.geometry import Polygon - - from pyresample.gradient import check_overlap - - # If either of the polygons is False, True is returned - assert check_overlap(False, 3) is True - assert check_overlap('eggs', False) is True - assert check_overlap(False, False) is True - - # If either the polygons is None, False is returned - assert check_overlap(None, 'bacon') is False - assert check_overlap('spam', None) is False - assert check_overlap(None, None) is False - - # If the polygons overlap, True is returned - poly1 = Polygon(((0, 0), (0, 1), (1, 1), (1, 0))) - poly2 = Polygon(((-1, -1), (-1, 1), (1, 1), (1, -1))) - assert check_overlap(poly1, poly2) is True - - # If the polygons do not overlap, False is returned - poly2 = Polygon(((5, 5), (6, 5), (6, 6), (5, 6))) - assert check_overlap(poly1, poly2) is False - - -def test_get_border_lonlats_geos(): - """Test that correct methods are called in get_border_lonlats() with geos inputs.""" - from pyresample.gradient import get_border_lonlats - geo_def = AreaDefinition("", "", "", - "+proj=geos +h=1234567", 2, 2, [1, 2, 3, 4]) - with mock.patch("pyresample.gradient.get_geostationary_bounding_box_in_lonlats") as get_geostationary_bounding_box: - get_geostationary_bounding_box.return_value = 1, 2 - res = get_border_lonlats(geo_def) - assert res == (1, 2) - get_geostationary_bounding_box.assert_called_with(geo_def, 3600) - - -def test_get_border_lonlats(): - """Test that correct methods are called in get_border_lonlats().""" - from pyresample.boundary import SimpleBoundary - from pyresample.gradient import get_border_lonlats - lon_sides = SimpleBoundary(side1=np.array([1]), side2=np.array([2]), - side3=np.array([3]), side4=np.array([4])) - lat_sides = SimpleBoundary(side1=np.array([1]), side2=np.array([2]), - side3=np.array([3]), side4=np.array([4])) - geo_def = AreaDefinition("", "", "", - "+proj=lcc +lat_1=25 +lat_2=25", 2, 2, [1, 2, 3, 4]) - with mock.patch.object(geo_def, "get_boundary_lonlats") as get_boundary_lonlats: - get_boundary_lonlats.return_value = lon_sides, lat_sides - lon_b, lat_b = get_border_lonlats(geo_def) - assert np.all(lon_b == np.array([1, 2, 3, 4])) - assert np.all(lat_b == np.array([1, 2, 3, 4])) - - -@mock.patch('pyresample.gradient.Polygon') -@mock.patch('pyresample.gradient.get_border_lonlats') -def test_get_polygon(get_border_lonlats, Polygon): - """Test polygon creation.""" - from pyresample.gradient import get_polygon - - # Valid polygon - get_border_lonlats.return_value = (1, 2) - geo_def = mock.MagicMock() - prj = mock.MagicMock() - x_borders = [0, 0, 1, 1] - y_borders = [0, 1, 1, 0] - boundary = [(0, 0), (0, 1), (1, 1), (1, 0)] - prj.return_value = (x_borders, y_borders) - poly = mock.MagicMock(area=2.0) - Polygon.return_value = poly - res = get_polygon(prj, geo_def) - get_border_lonlats.assert_called_with(geo_def) - prj.assert_called_with(1, 2) - Polygon.assert_called_with(boundary) - assert res is poly - - # Some border points are invalid, those should have been removed - x_borders = [np.inf, 0, 0, 0, 1, np.nan, 2] - y_borders = [-1, 0, np.nan, 1, 1, np.nan, -1] - boundary = [(0, 0), (0, 1), (1, 1), (2, -1)] - prj.return_value = (x_borders, y_borders) - res = get_polygon(prj, geo_def) - Polygon.assert_called_with(boundary) - assert res is poly - - # Polygon area is NaN - poly.area = np.nan - res = get_polygon(prj, geo_def) - assert res is None - - # Polygon area is 0.0 - poly.area = 0.0 - res = get_polygon(prj, geo_def) - assert res is None - - @mock.patch('pyresample.gradient.one_step_gradient_search') def test_gradient_resample_data(one_step_gradient_search): """Test that one_step_gradient_search() is called with proper array shapes.""" From c174f88a540ee6e3b0659606fb4c99e2b61826df Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Tue, 22 Oct 2024 13:56:20 +0200 Subject: [PATCH 5/6] Add area to swath resampling using gradient search --- pyresample/gradient/__init__.py | 13 ++++-- pyresample/slicer.py | 5 +- pyresample/test/test_gradient.py | 78 ++++++++++++++++++++------------ 3 files changed, 61 insertions(+), 35 deletions(-) diff --git a/pyresample/gradient/__init__.py b/pyresample/gradient/__init__.py index c8750108..e60be439 100644 --- a/pyresample/gradient/__init__.py +++ b/pyresample/gradient/__init__.py @@ -186,7 +186,10 @@ def _concatenate_chunks(chunks): def _fill_in_coords(target_geo_def, data_coords, data_dims): - x_coord, y_coord = target_geo_def.get_proj_vectors() + try: + x_coord, y_coord = target_geo_def.get_proj_vectors() + except AttributeError: + return None coords = [] for key in data_dims: if key == 'x': @@ -219,8 +222,6 @@ class ResampleBlocksGradientSearchResampler(BaseResampler): def __init__(self, source_geo_def, target_geo_def): """Init GradientResampler.""" - if isinstance(target_geo_def, SwathDefinition): - raise NotImplementedError("Cannot resample to a SwathDefinition.") if isinstance(source_geo_def, SwathDefinition): source_geo_def.lons = source_geo_def.lons.persist() source_geo_def.lats = source_geo_def.lats.persist() @@ -325,11 +326,13 @@ def _get_coordinates_in_same_projection(source_area, target_area): except AttributeError as err: lons, lats = source_area.get_lonlats() src_x, src_y = da.compute(lons, lats) + transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True) try: - transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True) dst_x, dst_y = transformer.transform(*target_area.get_proj_coords()) except AttributeError as err: - raise NotImplementedError("Cannot resample to Swath for now.") from err + # target is a swath definition + lons, lats = target_area.get_lonlats() + dst_x, dst_y = transformer.transform(*da.compute(lons, lats)) src_gradient_xl, src_gradient_xp = np.gradient(src_x, axis=[0, 1]) src_gradient_yl, src_gradient_yp = np.gradient(src_y, axis=[0, 1]) return (dst_x, dst_y), (src_gradient_xl, src_gradient_xp, src_gradient_yl, src_gradient_yp), (src_x, src_y) diff --git a/pyresample/slicer.py b/pyresample/slicer.py index 579fd58a..dceb6976 100644 --- a/pyresample/slicer.py +++ b/pyresample/slicer.py @@ -148,7 +148,10 @@ class AreaSlicer(Slicer): def get_polygon_to_contain(self): """Get the shapely Polygon corresponding to *area_to_contain* in projection coordinates of *area_to_crop*.""" from shapely.geometry import Polygon - x, y = self.area_to_contain.get_edge_bbox_in_projection_coordinates(frequency=10) + try: + x, y = self.area_to_contain.get_edge_bbox_in_projection_coordinates(frequency=10) + except AttributeError: + x, y = self.area_to_contain.get_edge_lonlats(vertices_per_side=10) if self.area_to_crop.is_geostationary: x_geos, y_geos = get_geostationary_bounding_box_in_proj_coords(self.area_to_crop, 360) x_geos, y_geos = self._transformer.transform(x_geos, y_geos, direction=TransformDirection.INVERSE) diff --git a/pyresample/test/test_gradient.py b/pyresample/test/test_gradient.py index c911f297..9706d713 100644 --- a/pyresample/test/test_gradient.py +++ b/pyresample/test/test_gradient.py @@ -325,7 +325,12 @@ class TestRBGradientSearchResamplerArea2Swath: def setup_method(self): """Set up the test case.""" - chunks = 20 + lons, lats = np.meshgrid(np.linspace(0, 20, 100), np.linspace(45, 66, 100)) + self.dst_swath = SwathDefinition(lons, lats, crs="WGS84") + lons, lats = self.dst_swath.get_lonlats(chunks=10) + lons = xr.DataArray(lons, dims=["y", "x"]) + lats = xr.DataArray(lats, dims=["y", "x"]) + self.dst_swath_dask = SwathDefinition(lons, lats) self.src_area = AreaDefinition('euro40', 'euro40', None, {'proj': 'stere', 'lon_0': 14.0, @@ -335,34 +340,49 @@ def setup_method(self): (-2717181.7304994687, -5571048.14031214, 1378818.2695005313, -1475048.1403121399)) - self.dst_area = AreaDefinition( - 'omerc_otf', - 'On-the-fly omerc area', - None, - {'alpha': '8.99811271718795', - 'ellps': 'sphere', - 'gamma': '0', - 'k': '1', - 'lat_0': '0', - 'lonc': '13.8096029486222', - 'proj': 'omerc', - 'units': 'm'}, - 50, 100, - (-1461111.3603, 3440088.0459, 1534864.0322, 9598335.0457) - ) - - self.lons, self.lats = self.dst_area.get_lonlats(chunks=chunks) - xrlons = xr.DataArray(self.lons.persist()) - xrlats = xr.DataArray(self.lats.persist()) - self.dst_swath = SwathDefinition(xrlons, xrlats) - - def test_resampling_to_swath_is_not_implemented(self): - """Test that resampling to swath is not working yet.""" - from pyresample.gradient import ResampleBlocksGradientSearchResampler - - with pytest.raises(NotImplementedError): - ResampleBlocksGradientSearchResampler(self.src_area, - self.dst_swath) + @pytest.mark.parametrize("input_dtype", (np.float32, np.float64)) + def test_resample_area_to_swath_2d(self, input_dtype): + """Resample swath to area, 2d.""" + swath_resampler = ResampleBlocksGradientSearchResampler(self.src_area, self.dst_swath_dask) + + data = xr.DataArray(da.ones(self.src_area.shape, dtype=input_dtype), + dims=['y', 'x']) + with np.errstate(invalid="ignore"): # 'inf' space pixels cause runtime warnings + swath_resampler.precompute() + res_xr = swath_resampler.compute(data, method='bilinear') + res_np = res_xr.compute(scheduler='single-threaded') + + assert res_xr.dtype == data.dtype + assert res_np.dtype == data.dtype + assert res_xr.shape == self.dst_swath.shape + assert res_np.shape == self.dst_swath.shape + assert type(res_xr) is type(data) + assert type(res_xr.data) is type(data.data) + assert not np.all(np.isnan(res_np)) + + @pytest.mark.parametrize("input_dtype", (np.float32, np.float64)) + def test_resample_area_to_swath_3d(self, input_dtype): + """Resample area to area, 3d.""" + swath_resampler = ResampleBlocksGradientSearchResampler(self.src_area, self.dst_swath_dask) + + data = xr.DataArray(da.ones((3, ) + self.src_area.shape, + dtype=input_dtype) * + np.array([1, 2, 3])[:, np.newaxis, np.newaxis], + dims=['bands', 'y', 'x']) + with np.errstate(invalid="ignore"): # 'inf' space pixels cause runtime warnings + swath_resampler.precompute() + res_xr = swath_resampler.compute(data, method='bilinear') + res_np = res_xr.compute(scheduler='single-threaded') + + assert res_xr.dtype == data.dtype + assert res_np.dtype == data.dtype + assert res_xr.shape == (3, ) + self.dst_swath.shape + assert res_np.shape == (3, ) + self.dst_swath.shape + assert type(res_xr) is type(data) + assert type(res_xr.data) is type(data.data) + for i in range(res_np.shape[0]): + arr = np.ravel(res_np[i, :, :]) + assert np.allclose(arr[np.isfinite(arr)], float(i + 1)) class TestEnsureDataArray(unittest.TestCase): From 4dd294877c9c24a69890d7634cefeb1f7c9c0270 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Thu, 24 Oct 2024 09:47:18 +0200 Subject: [PATCH 6/6] Refactor and test more --- pyresample/gradient/__init__.py | 36 ++++++++++++++--- pyresample/slicer.py | 40 +++++++++++++----- pyresample/test/test_gradient.py | 4 +- pyresample/test/test_slicer.py | 69 +++++++++++++++++++++++++++----- 4 files changed, 120 insertions(+), 29 deletions(-) diff --git a/pyresample/gradient/__init__.py b/pyresample/gradient/__init__.py index e60be439..28448b0d 100644 --- a/pyresample/gradient/__init__.py +++ b/pyresample/gradient/__init__.py @@ -53,12 +53,28 @@ def GradientSearchResampler(source_geo_def, target_geo_def): def create_gradient_search_resampler(source_geo_def, target_geo_def): """Create a gradient search resampler.""" - if ((isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, AreaDefinition)) or - (isinstance(source_geo_def, SwathDefinition) and isinstance(target_geo_def, AreaDefinition))): + if (is_area_to_area(source_geo_def, target_geo_def) or + is_swath_to_area(source_geo_def, target_geo_def) or + is_area_to_swath(source_geo_def, target_geo_def)): return ResampleBlocksGradientSearchResampler(source_geo_def, target_geo_def) raise NotImplementedError +def is_area_to_area(source_geo_def, target_geo_def): + """Check if source is area and target is area.""" + return isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, AreaDefinition) + + +def is_swath_to_area(source_geo_def, target_geo_def): + """Check if source is swath and target is area.""" + return isinstance(source_geo_def, SwathDefinition) and isinstance(target_geo_def, AreaDefinition) + + +def is_area_to_swath(source_geo_def, target_geo_def): + """Check if source is area and targed is swath.""" + return isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, SwathDefinition) + + def _gradient_resample_data(src_data, src_x, src_y, src_gradient_xl, src_gradient_xp, src_gradient_yl, src_gradient_yp, @@ -323,13 +339,18 @@ def gradient_resampler_indices(source_area, target_area, block_info=None, **kwar def _get_coordinates_in_same_projection(source_area, target_area): try: src_x, src_y = source_area.get_proj_coords() - except AttributeError as err: + work_crs = source_area.crs + except AttributeError: + # source is a swath definition, use target crs instead lons, lats = source_area.get_lonlats() src_x, src_y = da.compute(lons, lats) - transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True) + trans = pyproj.Transformer.from_crs(source_area.crs, target_area.crs, always_xy=True) + src_x, src_y = trans.transform(src_x, src_y) + work_crs = target_area.crs + transformer = pyproj.Transformer.from_crs(target_area.crs, work_crs, always_xy=True) try: dst_x, dst_y = transformer.transform(*target_area.get_proj_coords()) - except AttributeError as err: + except AttributeError: # target is a swath definition lons, lats = target_area.get_lonlats() dst_x, dst_y = transformer.transform(*da.compute(lons, lats)) @@ -345,6 +366,9 @@ def block_bilinear_interpolator(data, indices_xy, fill_value=np.nan, block_info= weight_l, l_start = np.modf(y_indices.clip(0, data.shape[-2] - 1)) weight_p, p_start = np.modf(x_indices.clip(0, data.shape[-1] - 1)) + weight_l = weight_l.astype(data.dtype) + weight_p = weight_p.astype(data.dtype) + l_start = l_start.astype(int) p_start = p_start.astype(int) l_end = np.clip(l_start + 1, 1, data.shape[-2] - 1) @@ -353,7 +377,7 @@ def block_bilinear_interpolator(data, indices_xy, fill_value=np.nan, block_info= res = ((1 - weight_l) * (1 - weight_p) * data[..., l_start, p_start] + (1 - weight_l) * weight_p * data[..., l_start, p_end] + weight_l * (1 - weight_p) * data[..., l_end, p_start] + - weight_l * weight_p * data[..., l_end, p_end]).astype(data.dtype) + weight_l * weight_p * data[..., l_end, p_end]) res = np.where(mask, fill_value, res) return res diff --git a/pyresample/slicer.py b/pyresample/slicer.py index dceb6976..0f6bcc52 100644 --- a/pyresample/slicer.py +++ b/pyresample/slicer.py @@ -67,11 +67,13 @@ class Slicer(ABC): """ - def __init__(self, area_to_crop, area_to_contain): + def __init__(self, area_to_crop, area_to_contain, work_crs): """Set up the Slicer.""" self.area_to_crop = area_to_crop self.area_to_contain = area_to_contain - self._transformer = Transformer.from_crs(self.area_to_contain.crs, self.area_to_crop.crs, always_xy=True) + + self._source_transformer = Transformer.from_crs(self.area_to_contain.crs, work_crs, always_xy=True) + self._target_transformer = Transformer.from_crs(self.area_to_crop.crs, work_crs, always_xy=True) def get_slices(self): """Get the slices to crop *area_to_crop* enclosing *area_to_contain*.""" @@ -92,17 +94,23 @@ def get_slices_from_polygon(self, poly): class SwathSlicer(Slicer): """A Slicer for cropping SwathDefinitions.""" + def __init__(self, area_to_crop, area_to_contain, work_crs=None): + """Set up the Slicer.""" + if work_crs is None: + work_crs = area_to_contain.crs + super().__init__(area_to_crop, area_to_contain, work_crs) + def get_polygon_to_contain(self): """Get the shapely Polygon corresponding to *area_to_contain* in lon/lat coordinates.""" from shapely.geometry import Polygon x, y = self.area_to_contain.get_edge_bbox_in_projection_coordinates(10) - poly = Polygon(zip(*self._transformer.transform(x, y))) + poly = Polygon(zip(*self._source_transformer.transform(x, y))) return poly def get_slices_from_polygon(self, poly): """Get the slices based on the polygon.""" intersecting_chunk_slices = [] - for smaller_poly, slices in _get_chunk_polygons_for_swath_to_crop(self.area_to_crop): + for smaller_poly, slices in self._get_chunk_polygons_for_swath_to_crop(self.area_to_crop): if smaller_poly.intersects(poly): intersecting_chunk_slices.append(slices) if not intersecting_chunk_slices: @@ -118,12 +126,18 @@ def _assemble_slices(chunk_slices): slices = col_slice, line_slice return slices + def _get_chunk_polygons_for_swath_to_crop(self, swath_to_crop): + """Get the polygons for each chunk of the area_to_crop.""" + from shapely.geometry import Polygon + for ((lons, lats), (line_slice, col_slice)) in _get_chunk_bboxes_for_swath_to_crop(swath_to_crop): + smaller_poly = Polygon(zip(*self._target_transformer.transform(lons, lats))) + yield (smaller_poly, (line_slice, col_slice)) + @lru_cache(maxsize=10) -def _get_chunk_polygons_for_swath_to_crop(swath_to_crop): - """Get the polygons for each chunk of the area_to_crop.""" +def _get_chunk_bboxes_for_swath_to_crop(swath_to_crop): + """Get the lon/lat bouding boxes for each chunk of the area_to_crop.""" res = [] - from shapely.geometry import Polygon src_chunks = swath_to_crop.lons.chunks for _position, (line_slice, col_slice) in _enumerate_chunk_slices(src_chunks): line_slice = expand_slice(line_slice) @@ -132,8 +146,7 @@ def _get_chunk_polygons_for_swath_to_crop(swath_to_crop): lons, lats = smaller_swath.get_edge_lonlats(10) lons = np.hstack(lons) lats = np.hstack(lats) - smaller_poly = Polygon(zip(lons, lats)) - res.append((smaller_poly, (line_slice, col_slice))) + res.append(((lons, lats), (line_slice, col_slice))) return res @@ -145,6 +158,11 @@ def expand_slice(small_slice): class AreaSlicer(Slicer): """A Slicer for cropping AreaDefinitions.""" + def __init__(self, area_to_crop, area_to_contain): + """Set up the Slicer.""" + work_crs = area_to_crop.crs + super().__init__(area_to_crop, area_to_contain, work_crs) + def get_polygon_to_contain(self): """Get the shapely Polygon corresponding to *area_to_contain* in projection coordinates of *area_to_crop*.""" from shapely.geometry import Polygon @@ -154,7 +172,7 @@ def get_polygon_to_contain(self): x, y = self.area_to_contain.get_edge_lonlats(vertices_per_side=10) if self.area_to_crop.is_geostationary: x_geos, y_geos = get_geostationary_bounding_box_in_proj_coords(self.area_to_crop, 360) - x_geos, y_geos = self._transformer.transform(x_geos, y_geos, direction=TransformDirection.INVERSE) + x_geos, y_geos = self._source_transformer.transform(x_geos, y_geos, direction=TransformDirection.INVERSE) geos_poly = Polygon(zip(x_geos, y_geos)) poly = Polygon(zip(x, y)) poly = poly.intersection(geos_poly) @@ -162,7 +180,7 @@ def get_polygon_to_contain(self): raise IncompatibleAreas('No slice on area.') x, y = zip(*poly.exterior.coords) - return Polygon(zip(*self._transformer.transform(x, y))) + return Polygon(zip(*self._source_transformer.transform(x, y))) def get_slices_from_polygon(self, poly_to_contain): """Get the slices based on the polygon.""" diff --git a/pyresample/test/test_gradient.py b/pyresample/test/test_gradient.py index 9706d713..5d93b17b 100644 --- a/pyresample/test/test_gradient.py +++ b/pyresample/test/test_gradient.py @@ -32,7 +32,7 @@ from pyresample.area_config import create_area_def from pyresample.geometry import AreaDefinition, SwathDefinition -from pyresample.gradient import ResampleBlocksGradientSearchResampler +from pyresample.gradient import ResampleBlocksGradientSearchResampler, create_gradient_search_resampler class TestRBGradientSearchResamplerArea2Area: @@ -343,7 +343,7 @@ def setup_method(self): @pytest.mark.parametrize("input_dtype", (np.float32, np.float64)) def test_resample_area_to_swath_2d(self, input_dtype): """Resample swath to area, 2d.""" - swath_resampler = ResampleBlocksGradientSearchResampler(self.src_area, self.dst_swath_dask) + swath_resampler = create_gradient_search_resampler(self.src_area, self.dst_swath_dask) data = xr.DataArray(da.ones(self.src_area.shape, dtype=input_dtype), dims=['y', 'x']) diff --git a/pyresample/test/test_slicer.py b/pyresample/test/test_slicer.py index af082f89..5868c363 100644 --- a/pyresample/test/test_slicer.py +++ b/pyresample/test/test_slicer.py @@ -223,15 +223,12 @@ def setUp(self): (-1461111.3603, 3440088.0459, 1534864.0322, 9598335.0457) ) - self.lons, self.lats = self.src_area.get_lonlats(chunks=chunks) - xrlons = xr.DataArray(self.lons.persist()) - xrlats = xr.DataArray(self.lats.persist()) - self.src_swath = SwathDefinition(xrlons, xrlats) + self.src_swath = swath_from_area(self.src_area, chunks) def test_slicer_init(self): """Test slicer initialization.""" slicer = create_slicer(self.src_swath, self.dst_area) - assert slicer.area_to_crop == self.src_area + assert slicer.area_to_crop == self.src_swath assert slicer.area_to_contain == self.dst_area def test_source_swath_slicing_does_not_return_full_dataset(self): @@ -246,17 +243,61 @@ def test_source_swath_slicing_does_not_return_full_dataset(self): def test_source_area_slicing_does_not_return_full_dataset(self): """Test source area covers dest area.""" - slicer = create_slicer(self.src_area, self.dst_area) + slicer = create_slicer(self.src_swath, self.dst_area) x_slice, y_slice = slicer.get_slices() assert x_slice.start == 0 - assert x_slice.stop == 35 - assert y_slice.start == 16 - assert y_slice.stop == 94 + assert x_slice.stop == 41 + assert y_slice.start == 9 + assert y_slice.stop == 91 + + def test_source_area_slicing_over_date_line(self): + src_area = AreaDefinition( + 'omerc_otf', + 'On-the-fly omerc area', + None, + {'alpha': '8.99811271718795', + 'ellps': 'sphere', + 'gamma': '0', + 'k': '1', + 'lat_0': '0', + 'lonc': '179.8096029486222', + 'proj': 'omerc', + 'units': 'm'}, + 50, 100, + (-1461111.3603, 3440088.0459, 1534864.0322, 9598335.0457) + ) + chunks = 10 + src_swath = swath_from_area(src_area, chunks) + + dst_area = AreaDefinition('somewhere in the pacific', 'somewhere', None, + {'proj': 'stere', 'lon_0': 180.0, + 'lat_0': 90.0, 'lat_ts': 60.0, + 'ellps': 'bessel'}, + 102, 102, + (-2717181.7304994687, -5571048.14031214, + 1378818.2695005313, -1475048.1403121399)) + + slicer = create_slicer(src_swath, dst_area) + x_slice, y_slice = slicer.get_slices() + assert x_slice.start == 0 + assert x_slice.stop == 41 + assert y_slice.start == 9 + assert y_slice.stop == 91 + + def test_source_area_slicing_with_custom_work_crs(self): + """Test source area covers dest area.""" + from pyresample.slicer import SwathSlicer + slicer = SwathSlicer(self.src_swath, self.dst_area, work_crs=self.src_area.crs) + x_slice, y_slice = slicer.get_slices() + assert x_slice.start == 0 + assert x_slice.stop == 41 + assert y_slice.start == 9 + assert y_slice.stop == 91 def test_area_get_polygon_returns_a_polygon(self): """Test getting a polygon returns a polygon.""" from shapely.geometry import Polygon - slicer = create_slicer(self.src_area, self.dst_area) + slicer = create_slicer(self.src_swath, self.dst_area) poly = slicer.get_polygon_to_contain() assert isinstance(poly, Polygon) @@ -271,3 +312,11 @@ def test_cannot_slice_a_string(self): """Test that we cannot slice a string.""" with pytest.raises(NotImplementedError): create_slicer("my_funky_area", self.dst_area) + + +def swath_from_area(src_area, chunks): + """Create a SwathDefinition from an AreaDefinition.""" + lons, lats = src_area.get_lonlats(chunks=chunks) + xrlons = xr.DataArray(lons.persist()) + xrlats = xr.DataArray(lats.persist()) + return SwathDefinition(xrlons, xrlats)