-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Allow users to configure the internal thread pool (#11) #54
Changes from 5 commits
7de6eb7
2d464f2
d445578
4231270
e643352
4cd311d
ec80d19
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -172,6 +172,7 @@ def __init__( | |
primary_dim_name: Optional[str] = None, | ||
primary_dim_property: Optional[str] = None, | ||
mask_value: Optional[float] = None, | ||
executor_kwargs: Optional[dict] = None, | ||
request_byte_limit: int = REQUEST_BYTE_LIMIT, | ||
): | ||
self.image_collection = image_collection | ||
|
@@ -203,6 +204,13 @@ def __init__( | |
coordinates=f'{self.primary_dim_name} {x_dim_name} {y_dim_name}', | ||
crs=self.crs_arg, | ||
) | ||
|
||
# Initialize executor_kwargs | ||
if executor_kwargs is None: | ||
self.executor_kwargs = {} | ||
else: | ||
self.executor_kwargs = executor_kwargs | ||
|
||
self._props = self._make_attrs_valid(self._props) | ||
# Scale in the projection's units. Typically, either meters or degrees. | ||
# If we use the default CRS i.e. EPSG:3857, the units is in meters. | ||
|
@@ -706,7 +714,7 @@ def reduce_bands(x, acc): | |
return target_image | ||
|
||
def _raw_indexing_method( | ||
self, key: tuple[Union[int, slice], ...] | ||
self, key: tuple[Union[int, slice], ...] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove extra space (to minimize the diff). |
||
) -> np.typing.ArrayLike: | ||
key, squeeze_axes = self._key_to_slices(key) | ||
|
||
|
@@ -756,8 +764,9 @@ def _raw_indexing_method( | |
for _ in range(shape[0]) | ||
] | ||
|
||
# TODO(#11): Allow users to configure this via kwargs. | ||
with concurrent.futures.ThreadPoolExecutor() as pool: | ||
|
||
# Pass executor_kwargs to ThreadPoolExecutor | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Comment is a bit too verbose for my taste. |
||
with concurrent.futures.ThreadPoolExecutor(**self.executor_kwargs) as pool: | ||
for (i, j, k), arr in pool.map( | ||
self._make_tile, self._tile_indexes(key[0], bbox) | ||
): | ||
|
@@ -845,6 +854,7 @@ def open_dataset( | |
primary_dim_name: Optional[str] = None, | ||
primary_dim_property: Optional[str] = None, | ||
ee_mask_value: Optional[float] = None, | ||
executor_kwargs: Optional[dict] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In addition to adding it to the arguments, please update the docstring. |
||
request_byte_limit: int = REQUEST_BYTE_LIMIT, | ||
) -> xarray.Dataset: | ||
"""Open an Earth Engine ImageCollection as an Xarray Dataset. | ||
|
@@ -936,6 +946,7 @@ def open_dataset( | |
primary_dim_name=primary_dim_name, | ||
primary_dim_property=primary_dim_property, | ||
mask_value=ee_mask_value, | ||
executor_kwargs=executor_kwargs, | ||
request_byte_limit=request_byte_limit, | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -360,6 +360,18 @@ def test_data_sanity_check(self): | |
self.assertNotEqual(temperature_2m.min(), 0.0) | ||
self.assertNotEqual(temperature_2m.max(), 0.0) | ||
|
||
def test_open_dataset_with_executor_kwargs(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the test! |
||
executor_kwargs = {'max_workers': 2} | ||
ds = self.entry.open_dataset( | ||
'ee://LANDSAT/LC08/C01/T1', | ||
drop_variables=tuple(f'B{i}' for i in range(3, 12)), | ||
scale=25.0, | ||
n_images=3, | ||
executor_kwargs=executor_kwargs, | ||
) | ||
|
||
self.assertEqual(ds.thread_pool.max_workers, executor_kwargs['max_workers']) | ||
|
||
def test_validate_band_attrs(self): | ||
ds = self.entry.open_dataset( | ||
'ee:LANDSAT/LC08/C01/T1', | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please initialize this at the top or bottom of the
__init__
.