diff --git a/pyproject.toml b/pyproject.toml index 942cb56..81a5fcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,8 @@ tests = [ "absl-py", "pytest", "pyink", + "rasterio", + "rioxarray", ] examples = [ "apache_beam[gcp]", diff --git a/xee/ext.py b/xee/ext.py index b6122e0..1618191 100644 --- a/xee/ext.py +++ b/xee/ext.py @@ -546,11 +546,11 @@ def _get_tile_from_ee( tile_index, band_id = tile_index bbox = self.project( (tile_index[0], 0, tile_index[1], 1) - if band_id == 'longitude' + if band_id == 'x' else (0, tile_index[0], 1, tile_index[1]) ) tile_idx = slice(tile_index[0], tile_index[1]) - target_image = ee.Image.pixelLonLat() + target_image = ee.Image.pixelCoordinates(ee.Projection(self.crs_arg)) return tile_idx, self.image_to_array( target_image, grid=bbox, dtype=np.float32, bandIds=[band_id] ) @@ -573,9 +573,7 @@ def _process_coordinate_data( self._get_tile_from_ee, list(zip(data, itertools.cycle([coordinate_type]))), ): - tiles[i] = ( - arr.tolist() if coordinate_type == 'longitude' else arr.tolist()[0] - ) + tiles[i] = arr.tolist() if coordinate_type == 'x' else arr.tolist()[0] return np.concatenate(tiles) def get_variables(self) -> utils.Frozen[str, xarray.Variable]: @@ -604,11 +602,11 @@ def get_variables(self) -> utils.Frozen[str, xarray.Variable]: lon_total_tile = math.ceil(v0.shape[1] / width_chunk) lon = self._process_coordinate_data( - lon_total_tile, width_chunk, v0.shape[1], 'longitude' + lon_total_tile, width_chunk, v0.shape[1], 'x' ) lat_total_tile = math.ceil(v0.shape[2] / height_chunk) lat = self._process_coordinate_data( - lat_total_tile, height_chunk, v0.shape[2], 'latitude' + lat_total_tile, height_chunk, v0.shape[2], 'y' ) width_coord = np.squeeze(lon) diff --git a/xee/ext_integration_test.py b/xee/ext_integration_test.py index e727465..d1eafc0 100644 --- a/xee/ext_integration_test.py +++ b/xee/ext_integration_test.py @@ -16,6 +16,7 @@ import json import os import pathlib +import tempfile from absl.testing import absltest from google.auth import identity_pool @@ -26,6 +27,13 @@ import ee +_SKIP_RASTERIO_TESTS = False +try: + import rasterio # pylint: disable=g-import-not-at-top + import rioxarray # pylint: disable=g-import-not-at-top,unused-import +except ImportError: + _SKIP_RASTERIO_TESTS = True + _CREDENTIALS_PATH_KEY = 'GOOGLE_APPLICATION_CREDENTIALS' _SCOPES = [ 'https://www.googleapis.com/auth/cloud-platform', @@ -397,6 +405,79 @@ def test_validate_band_attrs(self): for _, value in variable.attrs.items(): self.assertIsInstance(value, valid_types) + @absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded') + def test_write_projected_dataset_to_raster(self): + # ensure that a projected dataset written to a raster intersects with the + # point used to create the initial image collection + with tempfile.TemporaryDirectory() as temp_dir: + temp_file = os.path.join(temp_dir, 'test.tif') + + crs = 'epsg:32610' + proj = ee.Projection(crs) + point = ee.Geometry.Point([-122.44, 37.78]) + geom = point.buffer(1024).bounds() + + col = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') + col = col.filterBounds(point) + col = col.filter(ee.Filter.lte('CLOUDY_PIXEL_PERCENTAGE', 5)) + col = col.limit(10) + + ds = xr.open_dataset( + col, + engine=xee.EarthEngineBackendEntrypoint, + scale=10, + crs=crs, + geometry=geom, + ) + + ds = ds.isel(time=0).transpose('Y', 'X') + ds.rio.set_spatial_dims(x_dim='X', y_dim='Y', inplace=True) + ds.rio.write_crs(crs, inplace=True) + ds.rio.reproject(crs, inplace=True) + ds.rio.to_raster(temp_file) + + with rasterio.open(temp_file) as raster: + # see https://gis.stackexchange.com/a/407755 for evenOdd explanation + bbox = ee.Geometry.Rectangle(raster.bounds, proj=proj, evenOdd=False) + intersects = bbox.intersects(point, 1, proj=proj) + self.assertTrue(intersects.getInfo()) + + @absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded') + def test_write_dataset_to_raster(self): + # ensure that a dataset written to a raster intersects with the point used + # to create the initial image collection + with tempfile.TemporaryDirectory() as temp_dir: + temp_file = os.path.join(temp_dir, 'test.tif') + + crs = 'EPSG:4326' + proj = ee.Projection(crs) + point = ee.Geometry.Point([-122.44, 37.78]) + geom = point.buffer(1024).bounds() + + col = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') + col = col.filterBounds(point) + col = col.filter(ee.Filter.lte('CLOUDY_PIXEL_PERCENTAGE', 5)) + col = col.limit(10) + + ds = xr.open_dataset( + col, + engine=xee.EarthEngineBackendEntrypoint, + scale=0.0025, + geometry=geom, + ) + + ds = ds.isel(time=0).transpose('lat', 'lon') + ds.rio.set_spatial_dims(x_dim='lon', y_dim='lat', inplace=True) + ds.rio.write_crs(crs, inplace=True) + ds.rio.reproject(crs, inplace=True) + ds.rio.to_raster(temp_file) + + with rasterio.open(temp_file) as raster: + # see https://gis.stackexchange.com/a/407755 for evenOdd explanation + bbox = ee.Geometry.Rectangle(raster.bounds, proj=proj, evenOdd=False) + intersects = bbox.intersects(point, 1, proj=proj) + self.assertTrue(intersects.getInfo()) + if __name__ == '__main__': absltest.main()