Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transform dataset coordinates into specified crs #97

Merged
merged 9 commits into from
Dec 21, 2023
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ tests = [
"absl-py",
"pytest",
"pyink",
"rasterio",
"rioxarray",
]
examples = [
"apache_beam[gcp]",
Expand Down
12 changes: 5 additions & 7 deletions xee/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,11 +546,11 @@ def _get_tile_from_ee(
tile_index, band_id = tile_index
bbox = self.project(
(tile_index[0], 0, tile_index[1], 1)
if band_id == 'longitude'
if band_id == 'x'
else (0, tile_index[0], 1, tile_index[1])
)
tile_idx = slice(tile_index[0], tile_index[1])
target_image = ee.Image.pixelLonLat()
target_image = ee.Image.pixelCoordinates(ee.Projection(self.crs_arg))
return tile_idx, self.image_to_array(
target_image, grid=bbox, dtype=np.float32, bandIds=[band_id]
)
Expand All @@ -573,9 +573,7 @@ def _process_coordinate_data(
self._get_tile_from_ee,
list(zip(data, itertools.cycle([coordinate_type]))),
):
tiles[i] = (
arr.tolist() if coordinate_type == 'longitude' else arr.tolist()[0]
)
tiles[i] = arr.tolist() if coordinate_type == 'x' else arr.tolist()[0]
return np.concatenate(tiles)

def get_variables(self) -> utils.Frozen[str, xarray.Variable]:
Expand Down Expand Up @@ -604,11 +602,11 @@ def get_variables(self) -> utils.Frozen[str, xarray.Variable]:

lon_total_tile = math.ceil(v0.shape[1] / width_chunk)
lon = self._process_coordinate_data(
lon_total_tile, width_chunk, v0.shape[1], 'longitude'
lon_total_tile, width_chunk, v0.shape[1], 'x'
)
lat_total_tile = math.ceil(v0.shape[2] / height_chunk)
lat = self._process_coordinate_data(
lat_total_tile, height_chunk, v0.shape[2], 'latitude'
lat_total_tile, height_chunk, v0.shape[2], 'y'
)

width_coord = np.squeeze(lon)
Expand Down
74 changes: 74 additions & 0 deletions xee/ext_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
import json
import os
import pathlib
import tempfile

from absl.testing import absltest
from google.auth import identity_pool
import numpy as np
import xarray as xr
from xarray.core import indexing
import rioxarray
import rasterio
import xee

import ee
Expand Down Expand Up @@ -397,6 +400,77 @@ def test_validate_band_attrs(self):
for _, value in variable.attrs.items():
self.assertIsInstance(value, valid_types)

def test_write_projected_dataset_to_raster(self):
# ensure that a projected dataset written to a raster intersects with the
# point used to create the initial image collection
with tempfile.TemporaryDirectory() as temp_dir:
temp_file = os.path.join(temp_dir, 'test.tif')

crs = 'epsg:32610'
proj = ee.Projection(crs)
point = ee.Geometry.Point([-122.44, 37.78])
geom = point.buffer(1024).bounds()

col = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
col = col.filterBounds(point)
col = col.filter(ee.Filter.lte('CLOUDY_PIXEL_PERCENTAGE', 5))
col = col.limit(10)

ds = xr.open_dataset(
col,
engine=xee.EarthEngineBackendEntrypoint,
scale=10,
crs=crs,
geometry=geom,
)

ds = ds.isel(time=0).transpose('Y', 'X')
ds.rio.set_spatial_dims(x_dim='X', y_dim='Y', inplace=True)
ds.rio.write_crs(crs, inplace=True)
ds.rio.reproject(crs, inplace=True)
ds.rio.to_raster(temp_file)

with rasterio.open(temp_file) as raster:
# see https://gis.stackexchange.com/a/407755 for evenOdd explanation
bbox = ee.Geometry.Rectangle(raster.bounds, proj=proj, evenOdd=False)
intersects = bbox.intersects(point, 1, proj=proj)
self.assertTrue(intersects.getInfo())

def test_write_dataset_to_raster(self):
# ensure that a dataset written to a raster intersects with the point used
# to create the initial image collection
with tempfile.TemporaryDirectory() as temp_dir:
temp_file = os.path.join(temp_dir, 'test.tif')

crs = 'EPSG:4326'
proj = ee.Projection(crs)
point = ee.Geometry.Point([-122.44, 37.78])
geom = point.buffer(1024).bounds()

col = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
col = col.filterBounds(point)
col = col.filter(ee.Filter.lte('CLOUDY_PIXEL_PERCENTAGE', 5))
col = col.limit(10)

ds = xr.open_dataset(
col,
engine=xee.EarthEngineBackendEntrypoint,
scale=0.0025,
geometry=geom,
)

ds = ds.isel(time=0).transpose('lat', 'lon')
ds.rio.set_spatial_dims(x_dim='lon', y_dim='lat', inplace=True)
ds.rio.write_crs(crs, inplace=True)
ds.rio.reproject(crs, inplace=True)
ds.rio.to_raster(temp_file)

with rasterio.open(temp_file) as raster:
# see https://gis.stackexchange.com/a/407755 for evenOdd explanation
bbox = ee.Geometry.Rectangle(raster.bounds, proj=proj, evenOdd=False)
intersects = bbox.intersects(point, 1, proj=proj)
self.assertTrue(intersects.getInfo())


if __name__ == '__main__':
absltest.main()
Loading