diff --git a/xee/ext.py b/xee/ext.py index f96c365..1b50dc9 100644 --- a/xee/ext.py +++ b/xee/ext.py @@ -172,8 +172,15 @@ def __init__( primary_dim_name: Optional[str] = None, primary_dim_property: Optional[str] = None, mask_value: Optional[float] = None, + executor_kwargs: Optional[dict] = None, request_byte_limit: int = REQUEST_BYTE_LIMIT, ): + # Initialize executor_kwargs + if executor_kwargs is None: + self.executor_kwargs = {} + else: + self.executor_kwargs = executor_kwargs + self.image_collection = image_collection if n_images != -1: self.image_collection = image_collection.limit(n_images) @@ -203,7 +210,12 @@ def __init__( coordinates=f'{self.primary_dim_name} {x_dim_name} {y_dim_name}', crs=self.crs_arg, ) - self._props = self._make_attrs_valid(self._props) + # Initialize executor_kwargs + if executor_kwargs is None: + self.executor_kwargs = {} + else: + self.executor_kwargs = executor_kwargs + # Scale in the projection's units. Typically, either meters or degrees. # If we use the default CRS i.e. EPSG:3857, the units is in meters. default_scale = self.SCALE_UNITS.get(self.scale_units, 1) @@ -756,8 +768,9 @@ def _raw_indexing_method( for _ in range(shape[0]) ] - # TODO(#11): Allow users to configure this via kwargs. - with concurrent.futures.ThreadPoolExecutor() as pool: + + # Pass executor_kwargs to ThreadPoolExecutor + with concurrent.futures.ThreadPoolExecutor(**self.executor_kwargs) as pool: for (i, j, k), arr in pool.map( self._make_tile, self._tile_indexes(key[0], bbox) ): @@ -845,6 +858,7 @@ def open_dataset( primary_dim_name: Optional[str] = None, primary_dim_property: Optional[str] = None, ee_mask_value: Optional[float] = None, + executor_kwargs: Optional[dict] = None, request_byte_limit: int = REQUEST_BYTE_LIMIT, ) -> xarray.Dataset: """Open an Earth Engine ImageCollection as an Xarray Dataset. @@ -908,8 +922,8 @@ def open_dataset( 'system:time_start'. ee_mask_value (optional): Value to mask to EE nodata values. By default, this is 'np.iinfo(np.int32).max' i.e. 2147483647. - request_byte_limit: the max allowed bytes to request at a time from Earth - Engine. By default, it is 48MBs. + executor_kwargs (optional): A dictionary of keyword arguments to pass to + the ThreadPoolExecutor that handles the parallel computation of pixels. Returns: An xarray.Dataset that streams in remote data from Earth Engine. @@ -936,6 +950,7 @@ def open_dataset( primary_dim_name=primary_dim_name, primary_dim_property=primary_dim_property, mask_value=ee_mask_value, + executor_kwargs=executor_kwargs, request_byte_limit=request_byte_limit, ) diff --git a/xee/ext_integration_test.py b/xee/ext_integration_test.py index b44f0d5..dfc7407 100644 --- a/xee/ext_integration_test.py +++ b/xee/ext_integration_test.py @@ -360,6 +360,18 @@ def test_data_sanity_check(self): self.assertNotEqual(temperature_2m.min(), 0.0) self.assertNotEqual(temperature_2m.max(), 0.0) + def test_open_dataset_with_executor_kwargs(self): + executor_kwargs = {'max_workers': 2} + ds = self.entry.open_dataset( + 'ee://LANDSAT/LC08/C01/T1', + drop_variables=tuple(f'B{i}' for i in range(3, 12)), + scale=25.0, + n_images=3, + executor_kwargs=executor_kwargs, + ) + + self.assertEqual(ds.thread_pool.max_workers, executor_kwargs['max_workers']) + def test_validate_band_attrs(self): ds = self.entry.open_dataset( 'ee:LANDSAT/LC08/C01/T1',