From 408c1f5136f2e5c61d23a002db7c304d97703dc7 Mon Sep 17 00:00:00 2001 From: Xee authors Date: Tue, 10 Oct 2023 15:29:14 -0700 Subject: [PATCH] Automatically calculate optimal IO Chunks. By default, users should be able to get all 48 MBs worth of data in each request to EE. In this change, we change the default `io_chunk` behavior to make an educated guess such that users get as many bytes as possible under the request byte limit. Fixes #43. PiperOrigin-RevId: 572384058 --- xee/ext.py | 67 ++++++++++++++++++++++++++++++++++++-- xee/ext_test.py | 85 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+), 3 deletions(-) create mode 100644 xee/ext_test.py diff --git a/xee/ext.py b/xee/ext.py index 20df379..ab22e5b 100644 --- a/xee/ext.py +++ b/xee/ext.py @@ -64,6 +64,24 @@ 'double': np.float64, } +# While this documentation says that the limit is 10 MB... +# https://developers.google.com/earth-engine/guides/usage#request_payload_size +# actual byte limit seems to depend on other factors. This has been found via +# trial & error. +REQUEST_BYTE_LIMIT = 2**20 * 48 # 48 MBs + + +def _check_request_limit(chunks: dict[str, int], dtype_size: int, limit: int): + """Checks that the actual number of bytes exceeds the limit.""" + index, width, height = chunks['index'], chunks['width'], chunks['height'] + actual_bytes = index * width * height * dtype_size + if actual_bytes > limit: + raise ValueError( + f'`chunks="auto"` failed! Actual bytes {actual_bytes!r} exceeds limit' + f' {limit!r}. Please choose another value for `chunks` (and file a' + ' bug).' + ) + class _GetComputedPixels: """Wrapper around `ee.data.computePixels()` to make retries simple.""" @@ -121,6 +139,7 @@ def open( primary_dim_name: Optional[str] = None, primary_dim_property: Optional[str] = None, mask_value: Optional[float] = None, + request_byte_limit: int = REQUEST_BYTE_LIMIT, ) -> 'EarthEngineStore': if mode != 'r': raise ValueError( @@ -138,6 +157,7 @@ def open( primary_dim_name=primary_dim_name, primary_dim_property=primary_dim_property, mask_value=mask_value, + request_byte_limit=request_byte_limit, ) def __init__( @@ -152,6 +172,7 @@ def __init__( primary_dim_name: Optional[str] = None, primary_dim_property: Optional[str] = None, mask_value: Optional[float] = None, + request_byte_limit: int = REQUEST_BYTE_LIMIT, ): self.image_collection = image_collection if n_images != -1: @@ -220,10 +241,13 @@ def __init__( x_max, y_max = self.transform(x_max_0, y_max_0) self.bounds = x_min, y_min, x_max, y_max - self.chunks = self.PREFERRED_CHUNKS.copy() + max_dtype = self._max_itemsize() + + # TODO(b/291851322): Consider support for laziness when chunks=None. + # By default, automatically optimize io_chunks. + self.chunks = self._auto_chunks(max_dtype, request_byte_limit) if chunks == -1: self.chunks = -1 - # TODO(b/291851322): Consider support for laziness when chunks=None. elif chunks is not None and chunks != 'auto': self.chunks = self._assign_index_chunks(chunks) @@ -282,6 +306,38 @@ def image_ids(self) -> list[str]: image_ids, _ = self.image_collection_properties return image_ids + def _max_itemsize(self) -> int: + return max( + _parse_dtype(b['data_type']).itemsize for b in self._img_info['bands'] + ) + + @classmethod + def _auto_chunks( + cls, dtype_bytes: int, request_byte_limit: int = REQUEST_BYTE_LIMIT + ) -> dict[str, int]: + """Given the data type size and request limit, calculate optimal chunks.""" + # Taking the data type number of bytes into account, let's try to have the + # height and width follow round numbers (powers of two) and allocate the + # remaining bytes available for the index length. To illustrate this logic, + # let's follow through with an example where: + # request_byte_limit = 2 ** 20 * 10 # = 10 MBs + # dtype_bytes = 8 + log_total = np.log2(request_byte_limit) # e.g.=23.32... + log_dtype = np.log2(dtype_bytes) # e.g.=3 + log_limit = 10 * (log_total // 10) # e.g.=20 + log_index = log_total - log_limit # e.g.=3.32... + + # Motivation: How do we divide a number N into the closest sum of two ints? + d = (log_limit - np.ceil(log_dtype)) / 2 # e.g.=17/2=8.5 + wd, ht = np.ceil(d), np.floor(d) # e.g. wd=9, ht=8 + + # Put back to byte space, then round to the nearst integer number of bytes. + index = int(np.rint(2**log_index)) # e.g.=10 + width = int(np.rint(2**wd)) # e.g.=512 + height = int(np.rint(2**ht)) # e.g.=256 + + return {'index': index, 'width': width, 'height': height} + def _assign_index_chunks( self, input_chunk_store: dict[Any, Any] ) -> dict[Any, Any]: @@ -808,6 +864,7 @@ def open_dataset( primary_dim_name: Optional[str] = None, primary_dim_property: Optional[str] = None, ee_mask_value: Optional[float] = None, + request_byte_limit: int = REQUEST_BYTE_LIMIT, ) -> xarray.Dataset: """Open an Earth Engine ImageCollection as an Xarray Dataset. @@ -816,7 +873,8 @@ def open_dataset( ee.ImageCollection object. drop_variables (optional): Variables or bands to drop before opening. io_chunks (optional): Specifies the chunking strategy for loading data - from EE. + from EE. By default, this automatically calculates optional chunks based + on the `request_byte_limit`. n_images (optional): The max number of EE images in the collection to open. Useful when there are a large number of images in the collection since calculating collection size can be slow. -1 indicates that all @@ -869,6 +927,8 @@ def open_dataset( 'system:time_start'. ee_mask_value (optional): Value to mask to EE nodata values. By default, this is 'np.iinfo(np.int32).max' i.e. 2147483647. + request_byte_limit: the max allowed bytes to request at a time from Earth + Engine. By default, it is 48MBs. Returns: An xarray.Dataset that streams in remote data from Earth Engine. @@ -895,6 +955,7 @@ def open_dataset( primary_dim_name=primary_dim_name, primary_dim_property=primary_dim_property, mask_value=ee_mask_value, + request_byte_limit=request_byte_limit, ) store_entrypoint = backends_store.StoreBackendEntrypoint() diff --git a/xee/ext_test.py b/xee/ext_test.py new file mode 100644 index 0000000..74b47f4 --- /dev/null +++ b/xee/ext_test.py @@ -0,0 +1,85 @@ +"""Xee Unit Tests.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np +import xee + +from xee import ext + + +class EEStoreStandardDatatypesTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='int8', + dtype=np.dtype('int8'), + expected_chunks={'index': 48, 'width': 1024, 'height': 1024}, + ), + dict( + testcase_name='int32', + dtype=np.dtype('int32'), + expected_chunks={'index': 48, 'width': 512, 'height': 512}, + ), + dict( + testcase_name='int64', + dtype=np.dtype('int64'), + expected_chunks={'index': 48, 'width': 512, 'height': 256}, + ), + dict( + testcase_name='float32', + dtype=np.dtype('float32'), + expected_chunks={'index': 48, 'width': 512, 'height': 512}, + ), + dict( + testcase_name='float64', + dtype=np.dtype('float64'), + expected_chunks={'index': 48, 'width': 512, 'height': 256}, + ), + dict( + testcase_name='complex64', + dtype=np.dtype('complex64'), + expected_chunks={'index': 48, 'width': 512, 'height': 256}, + ), + ) + def test_auto_chunks__handles_standard_dtypes(self, dtype, expected_chunks): + self.assertEqual( + xee.EarthEngineStore._auto_chunks(dtype.itemsize), + expected_chunks, + '%r fails.' % dtype, + ) + + +class EEStoreTest(absltest.TestCase): + + def test_auto_chunks__handles_range_of_dtype_sizes(self): + dt = 0 + try: + for dt in range(1, 1024): + _ = xee.EarthEngineStore._auto_chunks(dt) + except ValueError: + self.fail(f'Could not handle data type size {dt}.') + + def test_auto_chunks__is_optimal_for_powers_of_two(self): + for p in range(10): + dt = 2**p + chunks = xee.EarthEngineStore._auto_chunks(dt) + self.assertEqual( + xee.REQUEST_BYTE_LIMIT, np.prod(list(chunks.values())) * dt + ) + + def test_exceeding_byte_limit__raises_error(self): + dtype_size = 8 + # does not fail + chunks = {'index': 48, 'width': 512, 'height': 256} + ext._check_request_limit(chunks, dtype_size, xee.REQUEST_BYTE_LIMIT) + + # fails + chunks = {'index': 1024, 'width': 1024, 'height': 1024} + with self.assertRaises(ValueError): + ext._check_request_limit(chunks, dtype_size, xee.REQUEST_BYTE_LIMIT) + + +if __name__ == '__main__': + absltest.main()