Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically calculate optimal IO Chunks. #68

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 64 additions & 3 deletions xee/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,24 @@
'double': np.float64,
}

# While this documentation says that the limit is 10 MB...
# https://developers.google.com/earth-engine/guides/usage#request_payload_size
# actual byte limit seems to depend on other factors. This has been found via
# trial & error.
REQUEST_BYTE_LIMIT = 2**20 * 48 # 48 MBs


def _check_request_limit(chunks: dict[str, int], dtype_size: int, limit: int):
"""Checks that the actual number of bytes exceeds the limit."""
index, width, height = chunks['index'], chunks['width'], chunks['height']
actual_bytes = index * width * height * dtype_size
if actual_bytes > limit:
raise ValueError(
f'`chunks="auto"` failed! Actual bytes {actual_bytes!r} exceeds limit'
f' {limit!r}. Please choose another value for `chunks` (and file a'
' bug).'
)


class _GetComputedPixels:
"""Wrapper around `ee.data.computePixels()` to make retries simple."""
Expand Down Expand Up @@ -121,6 +139,7 @@ def open(
primary_dim_name: Optional[str] = None,
primary_dim_property: Optional[str] = None,
mask_value: Optional[float] = None,
request_byte_limit: int = REQUEST_BYTE_LIMIT,
) -> 'EarthEngineStore':
if mode != 'r':
raise ValueError(
Expand All @@ -138,6 +157,7 @@ def open(
primary_dim_name=primary_dim_name,
primary_dim_property=primary_dim_property,
mask_value=mask_value,
request_byte_limit=request_byte_limit,
)

def __init__(
Expand All @@ -152,6 +172,7 @@ def __init__(
primary_dim_name: Optional[str] = None,
primary_dim_property: Optional[str] = None,
mask_value: Optional[float] = None,
request_byte_limit: int = REQUEST_BYTE_LIMIT,
):
self.image_collection = image_collection
if n_images != -1:
Expand Down Expand Up @@ -220,10 +241,13 @@ def __init__(
x_max, y_max = self.transform(x_max_0, y_max_0)
self.bounds = x_min, y_min, x_max, y_max

self.chunks = self.PREFERRED_CHUNKS.copy()
max_dtype = self._max_itemsize()

# TODO(b/291851322): Consider support for laziness when chunks=None.
# By default, automatically optimize io_chunks.
self.chunks = self._auto_chunks(max_dtype, request_byte_limit)
if chunks == -1:
self.chunks = -1
# TODO(b/291851322): Consider support for laziness when chunks=None.
elif chunks is not None and chunks != 'auto':
self.chunks = self._assign_index_chunks(chunks)

Expand Down Expand Up @@ -282,6 +306,38 @@ def image_ids(self) -> list[str]:
image_ids, _ = self.image_collection_properties
return image_ids

def _max_itemsize(self) -> int:
return max(
_parse_dtype(b['data_type']).itemsize for b in self._img_info['bands']
)

@classmethod
def _auto_chunks(
cls, dtype_bytes: int, request_byte_limit: int = REQUEST_BYTE_LIMIT
) -> dict[str, int]:
"""Given the data type size and request limit, calculate optimal chunks."""
# Taking the data type number of bytes into account, let's try to have the
# height and width follow round numbers (powers of two) and allocate the
# remaining bytes available for the index length. To illustrate this logic,
# let's follow through with an example where:
# request_byte_limit = 2 ** 20 * 10 # = 10 MBs
# dtype_bytes = 8
log_total = np.log2(request_byte_limit) # e.g.=23.32...
log_dtype = np.log2(dtype_bytes) # e.g.=3
log_limit = 10 * (log_total // 10) # e.g.=20
log_index = log_total - log_limit # e.g.=3.32...

# Motivation: How do we divide a number N into the closest sum of two ints?
d = (log_limit - np.ceil(log_dtype)) / 2 # e.g.=17/2=8.5
wd, ht = np.ceil(d), np.floor(d) # e.g. wd=9, ht=8

# Put back to byte space, then round to the nearst integer number of bytes.
index = int(np.rint(2**log_index)) # e.g.=10
width = int(np.rint(2**wd)) # e.g.=512
height = int(np.rint(2**ht)) # e.g.=256

return {'index': index, 'width': width, 'height': height}

def _assign_index_chunks(
self, input_chunk_store: dict[Any, Any]
) -> dict[Any, Any]:
Expand Down Expand Up @@ -808,6 +864,7 @@ def open_dataset(
primary_dim_name: Optional[str] = None,
primary_dim_property: Optional[str] = None,
ee_mask_value: Optional[float] = None,
request_byte_limit: int = REQUEST_BYTE_LIMIT,
) -> xarray.Dataset:
"""Open an Earth Engine ImageCollection as an Xarray Dataset.

Expand All @@ -816,7 +873,8 @@ def open_dataset(
ee.ImageCollection object.
drop_variables (optional): Variables or bands to drop before opening.
io_chunks (optional): Specifies the chunking strategy for loading data
from EE.
from EE. By default, this automatically calculates optional chunks based
on the `request_byte_limit`.
n_images (optional): The max number of EE images in the collection to
open. Useful when there are a large number of images in the collection
since calculating collection size can be slow. -1 indicates that all
Expand Down Expand Up @@ -869,6 +927,8 @@ def open_dataset(
'system:time_start'.
ee_mask_value (optional): Value to mask to EE nodata values. By default,
this is 'np.iinfo(np.int32).max' i.e. 2147483647.
request_byte_limit: the max allowed bytes to request at a time from Earth
Engine. By default, it is 48MBs.

Returns:
An xarray.Dataset that streams in remote data from Earth Engine.
Expand All @@ -895,6 +955,7 @@ def open_dataset(
primary_dim_name=primary_dim_name,
primary_dim_property=primary_dim_property,
mask_value=ee_mask_value,
request_byte_limit=request_byte_limit,
)

store_entrypoint = backends_store.StoreBackendEntrypoint()
Expand Down
85 changes: 85 additions & 0 deletions xee/ext_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Xee Unit Tests."""

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
import xee

from xee import ext


class EEStoreStandardDatatypesTest(parameterized.TestCase):

@parameterized.named_parameters(
dict(
testcase_name='int8',
dtype=np.dtype('int8'),
expected_chunks={'index': 48, 'width': 1024, 'height': 1024},
),
dict(
testcase_name='int32',
dtype=np.dtype('int32'),
expected_chunks={'index': 48, 'width': 512, 'height': 512},
),
dict(
testcase_name='int64',
dtype=np.dtype('int64'),
expected_chunks={'index': 48, 'width': 512, 'height': 256},
),
dict(
testcase_name='float32',
dtype=np.dtype('float32'),
expected_chunks={'index': 48, 'width': 512, 'height': 512},
),
dict(
testcase_name='float64',
dtype=np.dtype('float64'),
expected_chunks={'index': 48, 'width': 512, 'height': 256},
),
dict(
testcase_name='complex64',
dtype=np.dtype('complex64'),
expected_chunks={'index': 48, 'width': 512, 'height': 256},
),
)
def test_auto_chunks__handles_standard_dtypes(self, dtype, expected_chunks):
self.assertEqual(
xee.EarthEngineStore._auto_chunks(dtype.itemsize),
expected_chunks,
'%r fails.' % dtype,
)


class EEStoreTest(absltest.TestCase):

def test_auto_chunks__handles_range_of_dtype_sizes(self):
dt = 0
try:
for dt in range(1, 1024):
_ = xee.EarthEngineStore._auto_chunks(dt)
except ValueError:
self.fail(f'Could not handle data type size {dt}.')

def test_auto_chunks__is_optimal_for_powers_of_two(self):
for p in range(10):
dt = 2**p
chunks = xee.EarthEngineStore._auto_chunks(dt)
self.assertEqual(
xee.REQUEST_BYTE_LIMIT, np.prod(list(chunks.values())) * dt
)

def test_exceeding_byte_limit__raises_error(self):
dtype_size = 8
# does not fail
chunks = {'index': 48, 'width': 512, 'height': 256}
ext._check_request_limit(chunks, dtype_size, xee.REQUEST_BYTE_LIMIT)

# fails
chunks = {'index': 1024, 'width': 1024, 'height': 1024}
with self.assertRaises(ValueError):
ext._check_request_limit(chunks, dtype_size, xee.REQUEST_BYTE_LIMIT)


if __name__ == '__main__':
absltest.main()
Loading