diff --git a/xee/ext.py b/xee/ext.py index 65bd9e0..b094cbe 100644 --- a/xee/ext.py +++ b/xee/ext.py @@ -89,6 +89,17 @@ class EarthEngineStore(common.AbstractDataStore): DEFAULT_MASK_VALUE = np.iinfo(np.int32).max + ATTRS_VALID_TYPES = ( + str, + int, + float, + complex, + np.ndarray, + np.number, + list, + tuple + ) + @classmethod def open( cls, @@ -164,7 +175,7 @@ def __init__( coordinates=f'{self.primary_dim_name} {x_dim_name} {y_dim_name}', crs=self.crs_arg, ) - + self._props = self._make_attrs_valid(self._props) # Scale in the projection's units. Typically, either meters or degrees. # If we use the default CRS i.e. EPSG:3857, the units is in meters. default_scale = self.SCALE_UNITS.get(self.scale_units, 1) @@ -324,13 +335,28 @@ def _band_attrs(self, band_name: str) -> types.BandInfo: def _bands(self) -> list[str]: return [b['id'] for b in self._img_info['bands']] + def _make_attrs_valid( + self, attrs: dict[str, Any] + ) -> dict[ + str, + Union[ + str, int, float, complex, np.ndarray, np.number, list[Any], tuple[Any] + ], + ]: + return { + key: (str(value) + if not isinstance(value, self.ATTRS_VALID_TYPES) + else value) + for key, value in attrs.items() + } + def open_store_variable(self, name: str) -> xarray.Variable: arr = EarthEngineBackendArray(name, self) data = indexing.LazilyIndexedArray(arr) x_dim_name, y_dim_name = self.dimension_names dimensions = [self.primary_dim_name, x_dim_name, y_dim_name] - attrs = self._band_attrs(name) + attrs = self._make_attrs_valid(self._band_attrs(name)) encoding = { 'source': attrs['id'], 'scale_factor': arr.scale, diff --git a/xee/ext_integration_test.py b/xee/ext_integration_test.py index 0f69ad0..4e5b089 100644 --- a/xee/ext_integration_test.py +++ b/xee/ext_integration_test.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== r"""Integration tests for the Google Earth Engine backend for Xarray.""" - import pathlib from absl.testing import absltest @@ -358,6 +357,23 @@ def test_data_sanity_check(self): self.assertNotEqual(temperature_2m.min(), 0.0) self.assertNotEqual(temperature_2m.max(), 0.0) + def test_validate_band_attrs(self): + ds = self.entry.open_dataset( + 'ee:LANDSAT/LC08/C01/T1', + drop_variables=tuple(f'B{i}' for i in range(3, 12)), + scale=25.0, # in degrees + n_images=3, + ) + valid_types = (str, int, float, complex, np.ndarray, np.number, list, tuple) + + # Check attrs on the dataset itself + for _, value in ds.attrs.items(): + self.assertIsInstance(value, valid_types) + + # Check attrs on each variable within the dataset + for variable in ds.variables.values(): + for _, value in variable.attrs.items(): + self.assertIsInstance(value, valid_types) if __name__ == '__main__': absltest.main()