Skip to content

Commit

Permalink
Merge branch 'main' into translate_origin
Browse files Browse the repository at this point in the history
  • Loading branch information
dabhicusp committed Mar 29, 2024
2 parents 4d287b3 + 31bfded commit 36d36ef
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 29 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ i = ee.ImageCollection(ee.Image("LANDSAT/LC08/C02/T1_TOA/LC08_044034_20140318"))
ds = xarray.open_dataset(i, engine='ee')
```

Open any Earth Engine ImageCollection to match an existing transform:

```python
raster = rioxarray.open_rasterio(...) # assume crs + transform is set
ds = xr.open_dataset(
'ee://ECMWF/ERA5_LAND/HOURLY',
engine='ee',
geometry=tuple(raster.rio.bounds()), # must be in EPSG:4326
projection=ee.Projection(
crs=str(raster.rio.crs), transform=raster.rio.transform()[:6]
),
)
```

See [examples](https://github.com/google/Xee/tree/main/examples) or [docs](https://github.com/google/Xee/tree/main/docs) for more uses and integrations.

## How to run integration tests
Expand Down
80 changes: 51 additions & 29 deletions xee/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import math
import os
import sys
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Union
from urllib import parse
import warnings

Expand Down Expand Up @@ -144,7 +144,7 @@ def open(
crs: Optional[str] = None,
scale: Optional[float] = None,
projection: Optional[ee.Projection] = None,
geometry: Optional[ee.Geometry] = None,
geometry: ee.Geometry | Tuple[float, float, float, float] | None = None,
primary_dim_name: Optional[str] = None,
primary_dim_property: Optional[str] = None,
mask_value: Optional[float] = None,
Expand Down Expand Up @@ -185,7 +185,7 @@ def __init__(
crs: Optional[str] = None,
scale: Union[float, int, None] = None,
projection: Optional[ee.Projection] = None,
geometry: Optional[ee.Geometry] = None,
geometry: ee.Geometry | Tuple[float, float, float, float] | None = None,
primary_dim_name: Optional[str] = None,
primary_dim_property: Optional[str] = None,
mask_value: Optional[float] = None,
Expand Down Expand Up @@ -246,26 +246,7 @@ def __init__(
self.scale_x, self.scale_y = transform.a, transform.e
self.scale = np.sqrt(np.abs(transform.determinant))

# Parse the dataset bounds from the native projection (either from the CRS
# or the image geometry) and translate it to the representation that will be
# used for all internal `computePixels()` calls.
try:
if isinstance(geometry, ee.Geometry):
x_min_0, y_min_0, x_max_0, y_max_0 = _ee_bounds_to_bounds(
self.get_info['bounds']
)
else:
x_min_0, y_min_0, x_max_0, y_max_0 = self.crs.area_of_use.bounds
except AttributeError:
# `area_of_use` is probable `None`. Parse the geometry from the first
# image instead (calculated in self.get_info())
x_min_0, y_min_0, x_max_0, y_max_0 = _ee_bounds_to_bounds(
self.get_info['bounds']
)

x_min, y_min = self.transform(x_min_0, y_min_0)
x_max, y_max = self.transform(x_max_0, y_max_0)
self.bounds = x_min, y_min, x_max, y_max
self.bounds = self._determine_bounds(geometry=geometry)

max_dtype = self._max_itemsize()

Expand Down Expand Up @@ -297,9 +278,16 @@ def get_info(self) -> Dict[str, Any]:
rpcs.append(('projection', self.projection))

if isinstance(self.geometry, ee.Geometry):
rpcs.append(('bounds', self.geometry.bounds()))
rpcs.append(('bounds', self.geometry.bounds(1, proj=self.projection)))
else:
rpcs.append(('bounds', self.image_collection.first().geometry().bounds()))
rpcs.append(
(
'bounds',
self.image_collection.first()
.geometry()
.bounds(1, proj=self.projection),
)
)

# TODO(#29, #30): This RPC call takes the longest time to compute. This
# requires a full scan of the images in the collection, which happens on the
Expand Down Expand Up @@ -612,7 +600,7 @@ def _get_tile_from_ee(
)
target_image = ee.Image.pixelCoordinates(ee.Projection(self.crs_arg))
return tile_index, self.image_to_array(
target_image, grid=bbox, dtype=np.float32, bandIds=[band_id]
target_image, grid=bbox, dtype=np.float64, bandIds=[band_id]
)

def _process_coordinate_data(
Expand All @@ -636,6 +624,39 @@ def _process_coordinate_data(
tiles[i] = arr.flatten()
return np.concatenate(tiles)

def _determine_bounds(
    self,
    geometry: ee.Geometry | Tuple[float, float, float, float] | None = None,
) -> Tuple[float, float, float, float]:
    """Resolves the dataset bounds and maps them through `self.transform`.

    Args:
      geometry: Selects where the raw bounds come from.
        * `None`: `self.crs.area_of_use.bounds`; if that lookup raises
          `AttributeError`, the 'bounds' entry of `self.get_info` is used.
        * `ee.Geometry`: the 'bounds' entry of `self.get_info`.
        * A length-4 sequence `(x_min, y_min, x_max, y_max)`: used as given.

    Returns:
      `(x_min, y_min, x_max, y_max)`, where each corner is the result of
      passing the corresponding raw corner through `self.transform`.

    Raises:
      ValueError: If `geometry` is a sequence whose length is not 4, or is
        neither `None`, an `ee.Geometry`, nor a non-string sequence.
    """
    if geometry is None:
        try:
            raw_bounds = self.crs.area_of_use.bounds
        except AttributeError:
            # `area_of_use` is probably `None`. Parse the geometry from the
            # first image instead (calculated in self.get_info).
            raw_bounds = _ee_bounds_to_bounds(self.get_info['bounds'])
    elif isinstance(geometry, ee.Geometry):
        raw_bounds = _ee_bounds_to_bounds(self.get_info['bounds'])
    elif isinstance(geometry, (str, bytes, bytearray)) or not isinstance(
        geometry, Sequence
    ):
        # Strings and bytes are `Sequence`s too; without this guard a
        # 4-character string would pass the length check below and only fail
        # later inside `self.transform` with an unrelated error.
        raise ValueError(
            'geometry must be a tuple or list of length 4, a ee.Geometry, or'
            f' None but got {type(geometry)}'
        )
    elif len(geometry) != 4:
        raise ValueError(
            'geometry must be a tuple or list of length 4, or a ee.Geometry, '
            f'but got {geometry!r}'
        )
    else:
        raw_bounds = geometry

    x_min_0, y_min_0, x_max_0, y_max_0 = raw_bounds
    x_min, y_min = self.transform(x_min_0, y_min_0)
    x_max, y_max = self.transform(x_max_0, y_max_0)
    return x_min, y_min, x_max, y_max

def get_variables(self) -> utils.Frozen[str, xarray.Variable]:
vars_ = [(name, self.open_store_variable(name)) for name in self._bands()]

Expand Down Expand Up @@ -719,7 +740,7 @@ def _parse_dtype(data_type: types.DataType):


def _ee_bounds_to_bounds(bounds: ee.Bounds) -> types.Bounds:
coords = np.array(bounds['coordinates'], dtype=np.float32)[0]
coords = np.array(bounds['coordinates'], dtype=np.float64)[0]
x_min, y_min, x_max, y_max = (
min(coords[:, 0]),
min(coords[:, 1]),
Expand Down Expand Up @@ -974,7 +995,7 @@ def open_dataset(
crs: Optional[str] = None,
scale: Union[float, int, None] = None,
projection: Optional[ee.Projection] = None,
geometry: Optional[ee.Geometry] = None,
geometry: ee.Geometry | Tuple[float, float, float, float] | None = None,
primary_dim_name: Optional[str] = None,
primary_dim_property: Optional[str] = None,
ee_mask_value: Optional[float] = None,
Expand Down Expand Up @@ -1035,7 +1056,8 @@ def open_dataset(
coalesce all variables upon opening. By default, the scale and reference
system is set by the `crs` and `scale` arguments.
geometry (optional): Specify an `ee.Geometry` to define the regional
bounds when opening the data. When not set, the bounds are defined by
bounds when opening the data or a bbox specifying [x_min, y_min, x_max,
y_max] in EPSG:4326. When not set, the bounds are defined by
the CRS's `area_of_use` boundaries. If those aren't present, the bounds
are derived from the geometry of the first image of the collection.
primary_dim_name (optional): Override the name of the primary dimension of
Expand Down
65 changes: 65 additions & 0 deletions xee/ext_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,32 @@ def __getitem__(self, params):

self.assertEqual(getter.count, 3)

def test_geometry_bounds_with_and_without_projection(self):
    """Checks that the store's 'bounds' are computed in its projection.

    The store's 'bounds' must equal `geometry.bounds(1, proj=projection)` and
    must differ from the argument-less `geometry.bounds()`.
    """
    scale = 5000
    distance = 311.5
    projection = ee.Projection('EPSG:4326', [1, 0, 0, 0, -1, 0]).atScale(scale)

    first_image = (
        ee.ImageCollection('LANDSAT/LC08/C01/T1')
        .filterDate('2017-01-01', '2017-01-03')
        .first()
    )
    reprojected = first_image.reproject(projection)

    # NOTE(review): if these arguments are (lon, lat), the second value is
    # outside the valid latitude range -- confirm whether this is intended.
    point = ee.Geometry.Point(-40.2414893624401, 105.48790177216375)
    buffered = point.buffer(distance, proj=projection)
    geometry = buffered.bounds(proj=projection)

    store = xee.EarthEngineStore(
        ee.ImageCollection(reprojected),
        projection=reprojected.projection(),
        geometry=geometry,
    )
    store_bounds = store.get_info['bounds']

    default_bounds = geometry.bounds().getInfo()
    self.assertNotEqual(default_bounds, store_bounds)

    projected_bounds = geometry.bounds(1, proj=projection).getInfo()
    self.assertEqual(projected_bounds, store_bounds)

def test_getitem_kwargs(self):
arr = xee.EarthEngineBackendArray('B4', self.store)
self.assertEqual(arr.store.getitem_kwargs['initial_delay'], 1500)
Expand Down Expand Up @@ -402,6 +428,45 @@ def test_honors_projection(self):
self.assertEqual(ds.dims, {'time': 4248, 'lon': 3600, 'lat': 1800})
self.assertNotEqual(ds.dims, standard_ds.dims)

@absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded')
def test_expected_precise_transform(self):
    """Checks that an opened dataset reproduces a reference raster's transform.

    A local raster with unequal x/y resolution is built, then a dataset is
    opened with that raster's bounds and transform. The resulting transform
    must be element-wise equal to the raster's.
    """
    n_rows, n_cols = 162, 121
    data = np.empty((n_rows, n_cols), dtype=np.float32)

    x_min = -53.94158617595226
    y_min = -12.078281822698678
    x_max = -53.67209159071253
    y_max = -11.714464132625046
    x_res = (x_max - x_min) / n_cols
    y_res = (y_max - y_min) / n_rows

    # NOTE(review): the y axis is offset by `x_res` and the x axis by `y_res`,
    # which looks swapped. Kept as-is because the final assertion compares
    # transforms exactly -- confirm before changing.
    y_coords = np.linspace(y_max, y_min + x_res, n_rows)
    x_coords = np.linspace(x_min, x_max - y_res, n_cols)
    raster = xr.DataArray(
        data,
        coords={'y': y_coords, 'x': x_coords},
        dims=('y', 'x'),
    )
    raster.rio.write_crs('EPSG:4326', inplace=True)

    date_range = ee.DateRange('2014-01-01', '2014-01-02')
    ic = (
        ee.ImageCollection('UCSB-CHG/CHIRPS/DAILY')
        .filterDate(date_range)
        .select('precipitation')
    )

    target_projection = ee.Projection(
        crs=str(raster.rio.crs), transform=raster.rio.transform()[:6]
    )
    opened = xr.open_dataset(
        ee.ImageCollection(ic),
        engine='ee',
        geometry=tuple(raster.rio.bounds()),
        projection=target_projection,
    )
    xee_dataset = opened.rename({'lon': 'x', 'lat': 'y'})

    # The fixture is only meaningful when the pixels are not square.
    self.assertNotEqual(abs(x_res), abs(y_res))
    np.testing.assert_equal(
        np.array(xee_dataset.rio.transform()),
        np.array(raster.rio.transform()),
    )

def test_parses_ee_url(self):
ds = self.entry.open_dataset(
'ee://LANDSAT/LC08/C01/T1',
Expand Down

0 comments on commit 36d36ef

Please sign in to comment.