Skip to content

Commit

Permalink
Merge pull request #121 from dabhicusp:tile_improve
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 600454677
  • Loading branch information
Xee authors committed Jan 22, 2024
2 parents 31fb949 + 14da6de commit 10380ea
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions xee/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,18 +562,17 @@ def _get_primary_coordinates(self) -> List[Any]:
return primary_coords

def _get_tile_from_ee(
self, tile_index: Tuple[Any, Union[str, int]]
) -> Tuple[slice, np.ndarray]:
self, tile_and_band: Tuple[Tuple[int, int, int], str]
) -> Tuple[int, np.ndarray[Any, np.dtype]]:
"""Get a numpy array from EE for a specific bounding box (a 'tile')."""
tile_index, band_id = tile_index
(tile_index, tile_coords_start, tile_coords_end), band_id = tile_and_band
bbox = self.project(
(tile_index[0], 0, tile_index[1], 1)
(tile_coords_start, 0, tile_coords_end, 1)
if band_id == 'x'
else (0, tile_index[0], 1, tile_index[1])
else (0, tile_coords_start, 1, tile_coords_end)
)
tile_idx = slice(tile_index[0], tile_index[1])
target_image = ee.Image.pixelCoordinates(ee.Projection(self.crs_arg))
return tile_idx, self.image_to_array(
return tile_index, self.image_to_array(
target_image, grid=bbox, dtype=np.float32, bandIds=[band_id]
)

Expand All @@ -586,7 +585,7 @@ def _process_coordinate_data(
) -> np.ndarray:
"""Process coordinate data using multithreading for longitude or latitude."""
data = [
(tile_size * i, min(tile_size * (i + 1), end_point))
(i, tile_size * i, min(tile_size * (i + 1), end_point))
for i in range(tile_count)
]
tiles = [None] * tile_count
Expand All @@ -595,8 +594,8 @@ def _process_coordinate_data(
self._get_tile_from_ee,
list(zip(data, itertools.cycle([coordinate_type]))),
):
tiles[i] = arr.tolist() if coordinate_type == 'x' else arr.tolist()[0]
return np.concatenate(tiles)
tiles[i] = arr.flatten()
return np.array(tiles)

def get_variables(self) -> utils.Frozen[str, xarray.Variable]:
vars_ = [(name, self.open_store_variable(name)) for name in self._bands()]
Expand Down

0 comments on commit 10380ea

Please sign in to comment.