This adds an Earth Engine initialization check to dataset operations so that remote workers can call Earth Engine. Also adds docs for submitting a Dataflow job using Xee

Dataflow jobs would fail with Xee because the remote workers did not have the EE client library initialized. This adds a check to all `EarthEngineBackendArray` methods that call EE, so that the library is initialized on the worker if it is not already.

There was discussion on issue #99 regarding documentation for how to initialize/authenticate on a distributed cluster, so this change also includes a Dataflow example that users can start from.
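For example, with these changes a pipeline can opt in when opening the dataset. A minimal sketch (the collection ID and `init_kwargs` values here are placeholders, not part of this change):

import ee
import xarray as xr
import xee

ee.Initialize()  # normal client-side initialization on the launcher machine

ds = xr.open_dataset(
    ee.ImageCollection('NASA/GPM_L3/IMERG_V06'),
    engine=xee.EarthEngineBackendEntrypoint,
    # New in this change: workers lazily call ee.Initialize() with
    # application default credentials if EE is not yet initialized.
    try_auto_init=True,
    # Also new: extra keyword arguments forwarded to ee.Initialize().
    init_kwargs={'opt_url': 'https://earthengine-highvolume.googleapis.com'},
)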

close #51

PiperOrigin-RevId: 596966033
Xee authors committed Jan 9, 2024
1 parent f05e82b commit a174bfe
Showing 6 changed files with 201 additions and 1 deletion.
5 changes: 5 additions & 0 deletions examples/dataflow/Dockerfile
@@ -0,0 +1,5 @@
FROM apache/beam_python3.9_sdk:2.51.0

COPY requirements.txt ./

RUN pip install -r requirements.txt
Empty file added examples/dataflow/README.md
4 changes: 4 additions & 0 deletions examples/dataflow/cloudbuild.yaml
@@ -0,0 +1,4 @@
steps:
- name: 'gcr.io/cloud-builders/docker'
args: [ 'build', '-t', 'REGION-docker.pkg.dev/YOUR_PROJECT/REPO/CONTAINER', '.' ]
images: ['REGION-docker.pkg.dev/YOUR_PROJECT/REPO/CONTAINER']
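With this config, one way to build and push the container is to run `gcloud builds submit --config cloudbuild.yaml .` from the examples/dataflow directory, after substituting your own region, project, Artifact Registry repository, and container names.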
133 changes: 133 additions & 0 deletions examples/dataflow/ee_to_zarr_dataflow.py
@@ -0,0 +1,133 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports EE ImageCollections to Zarr using Xarray-Beam."""

# example:
# python ee_to_zarr_dataflow.py
# --input NASA/GPM_L3/IMERG_V06
# --output gs://xee-out-${PROJECT_NUMBER}
# --target_chunks='time=6'
# --runner DataflowRunner
# --project $PROJECT
# --region $REGION
# --temp_location gs://xee-out-${PROJECT_NUMBER}/tmp/
# --service_account_email $SERVICE_ACCOUNT
# --sdk_location=container
# --sdk_container_image=${REGION}-docker.pkg.dev/${PROJECT_NAME}/${REPO}/${CONTAINER}
# --subnetwork regions/${REGION}/subnetworks/${NETWORK_NAME}
# --job_name imerg-dataflow-test-$(date '+%Y%m%d%H%M%S')

import logging

from absl import app
from absl import flags
import apache_beam as beam
from apache_beam.internal import pickler
import xarray as xr
import xarray_beam as xbeam
import xee

import ee

pickler.set_library(pickler.USE_CLOUDPICKLE)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


_INPUT = flags.DEFINE_string(
    'input', '', help='The input Earth Engine ImageCollection.'
)
_CRS = flags.DEFINE_string(
    'crs',
    'EPSG:4326',
    help='Coordinate Reference System for output Zarr.',
)
_SCALE = flags.DEFINE_float(
    'scale', 0.25, help='Scale, in CRS units, for the output Zarr.'
)
_TARGET_CHUNKS = flags.DEFINE_string(
    'target_chunks',
    '',
    help=(
        'Chunks to set on the output Zarr store, in the form of comma-'
        "separated dimension=size pairs, e.g. --target_chunks='x=10,y=10'. "
        'Omitted dimensions are not changed, and a chunk size of -1 '
        'indicates not to chunk a dimension.'
    ),
)
_OUTPUT = flags.DEFINE_string('output', '', help='The output zarr path.')
_RUNNER = flags.DEFINE_string('runner', None, help='beam.runners.Runner')


# pylint: disable=unused-argument
def parse_dataflow_flags(argv: list[str]):
  parser = flags.argparse_flags.ArgumentParser(
      description='parser for dataflow flags',
      allow_abbrev=False,
  )
  _, dataflow_args = parser.parse_known_args()
  return dataflow_args


# Borrowed from the xbeam examples:
# https://github.com/google/xarray-beam/blob/4f4fcb965a65b5d577601af311d0e0142ee38076/examples/xbeam_rechunk.py#L41
def _parse_chunks_str(chunks_str: str) -> dict[str, int]:
  chunks = {}
  parts = chunks_str.split(',')
  for part in parts:
    if not part:  # Tolerate an empty --target_chunks value.
      continue
    k, v = part.split('=')
    chunks[k] = int(v)
  return chunks


def main(argv: list[str]) -> None:
  assert _INPUT.value, 'Must specify --input'
  assert _OUTPUT.value, 'Must specify --output'

  source_chunks = {'time': 24}
  target_chunks = dict(source_chunks, **_parse_chunks_str(_TARGET_CHUNKS.value))

  ee.Initialize()

  input_coll = (
      ee.ImageCollection(_INPUT.value)
      .limit(100, 'system:time_start', True)
      .select('precipitationCal')
  )

  ds = xr.open_dataset(
      input_coll,
      crs=_CRS.value,
      scale=_SCALE.value,
      engine=xee.EarthEngineBackendEntrypoint,
  )
  template = xbeam.make_template(ds)
  itemsize = max(variable.dtype.itemsize for variable in template.values())

  with beam.Pipeline(runner=_RUNNER.value, argv=argv) as root:
    _ = (
        root
        | xbeam.DatasetToChunks(ds, source_chunks)
        | xbeam.Rechunk(
            ds.sizes,
            source_chunks,
            target_chunks,
            itemsize=itemsize,
        )
        | xbeam.ChunksToZarr(_OUTPUT.value, template, target_chunks)
    )


if __name__ == '__main__':
  app.run(main, flags_parser=parse_dataflow_flags)
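For a quick local sanity check, one could presumably drop the Dataflow-specific flags and pass --runner DirectRunner instead, so the pipeline runs in-process before being submitted to Dataflow.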
8 changes: 8 additions & 0 deletions examples/dataflow/requirements.txt
@@ -0,0 +1,8 @@
absl-py
earthengine-api
gcsfs
google-cloud
google-cloud-storage
xarray
xarray-beam
apache-beam[gcp]
52 changes: 51 additions & 1 deletion xee/ext.py
@@ -30,6 +30,7 @@
import warnings

import affine
import google.auth
import numpy as np
import pandas as pd
import pyproj
@@ -144,6 +145,8 @@ def open(
      primary_dim_property: Optional[str] = None,
      mask_value: Optional[float] = None,
      request_byte_limit: int = REQUEST_BYTE_LIMIT,
      init_kwargs: Optional[Dict[str, Any]] = None,
      try_auto_init: bool = False,
  ) -> 'EarthEngineStore':
    if mode != 'r':
      raise ValueError(
@@ -162,6 +165,8 @@
        primary_dim_property=primary_dim_property,
        mask_value=mask_value,
        request_byte_limit=request_byte_limit,
        init_kwargs=init_kwargs,
        try_auto_init=try_auto_init,
    )

  def __init__(
@@ -177,7 +182,13 @@ def __init__(
      primary_dim_property: Optional[str] = None,
      mask_value: Optional[float] = None,
      request_byte_limit: int = REQUEST_BYTE_LIMIT,
      init_kwargs: Optional[Dict[str, Any]] = None,
      try_auto_init: bool = False,
  ):

    self.init_kwargs = init_kwargs
    self.try_auto_init = try_auto_init

    self.image_collection = image_collection
    if n_images != -1:
      self.image_collection = image_collection.limit(n_images)
@@ -723,7 +734,32 @@ def __init__(self, variable_name: str, ee_store: EarthEngineStore):
    if isinstance(self.store.chunks, dict):
      self._apparent_chunks = self.store.chunks.copy()

  def _ee_init_check(self):
    if not ee.data.is_initialized() and self.store.try_auto_init:
      warnings.warn(
          'Earth Engine is not initialized on worker. '
          'Attempting to initialize using application default credentials.'
      )

      # Check whether init_kwargs were provided.
      # If not, use an empty dict to hold the default credentials.
      if self.store.init_kwargs is None:
        kwargs = {}
      else:
        kwargs = self.store.init_kwargs

      # Get the default credentials. Use these rather than
      # google.auth.compute_engine credentials in case this runs where the
      # compute service account is not available (i.e. local workers).
      credentials, _ = google.auth.default()
      # Set the credentials keyword to the default worker credentials.
      # This overrides any value provided by the user.
      kwargs['credentials'] = credentials

      ee.Initialize(**kwargs)

  def __getitem__(self, key: indexing.ExplicitIndexer) -> np.typing.ArrayLike:
    self._ee_init_check()
    return indexing.explicit_indexing_adapter(
        key,
        self.shape,
@@ -748,6 +784,7 @@ def _key_to_slices(
      A `key` tuple where all elements are `slice`s, and a `squeeze_axes` tuple,
      which can be used as the second argument to np.squeeze().
    """
    self._ee_init_check()
    key_new = []
    squeeze_axes = []
    for axis, k in enumerate(key):
@@ -762,6 +799,7 @@ def _slice_collection(self, image_slice: slice) -> ee.Image:
    """Reduce the ImageCollection into an Image with bands as index slices."""
    # Get the right range of Images in the collection, either a single image or
    # a range of images...
    self._ee_init_check()
    start, stop, stride = image_slice.indices(self.shape[0])

    # If the input images have IDs, just slice them. Otherwise, we need to do
@@ -792,6 +830,7 @@ def reduce_bands(x, acc):
  def _raw_indexing_method(
      self, key: Tuple[Union[int, slice], ...]
  ) -> np.typing.ArrayLike:
    self._ee_init_check()
    key, squeeze_axes = self._key_to_slices(key)

    # TODO(#13): honor step increments
@@ -817,7 +856,7 @@ def _raw_indexing_method(

    # Here, we break up the requested bounding box into smaller bounding boxes
    # that are at most the chunk size. We will divide up the requests for
    # pixels across a thread pool. We then need to combine all the arrays into
    # one big array.
    #
    # Lucky for us, Numpy provides a specialized "concat"-like operation for
@@ -858,6 +897,7 @@ def _make_tile(
      self, tile_index: Tuple[types.TileIndex, types.BBox3d]
  ) -> Tuple[types.TileIndex, np.ndarray]:
    """Get a numpy array from EE for a specific 3D bounding box (a 'tile')."""
    self._ee_init_check()
    tile_idx, (istart, iend, *bbox) = tile_index
    target_image = self._slice_collection(slice(istart, iend))
    return tile_idx, self.store.image_to_array(
@@ -868,6 +908,7 @@ def _tile_indexes(
      self, index_range: slice, bbox: types.BBox
  ) -> Iterable[Tuple[types.TileIndex, types.BBox3d]]:
    """Calculate indexes to break up a (3D) bounding box into chunks."""
    self._ee_init_check()
    tstep = self._apparent_chunks['index']
    wstep = self._apparent_chunks['width']
    hstep = self._apparent_chunks['height']
@@ -931,6 +972,8 @@ def open_dataset(
      primary_dim_property: Optional[str] = None,
      ee_mask_value: Optional[float] = None,
      request_byte_limit: int = REQUEST_BYTE_LIMIT,
      try_auto_init: bool = False,
      init_kwargs: Optional[Dict[str, Any]] = None,
  ) -> xarray.Dataset:  # type: ignore
    """Open an Earth Engine ImageCollection as an Xarray Dataset.
@@ -995,6 +1038,11 @@
        this is 'np.iinfo(np.int32).max', i.e. 2147483647.
      request_byte_limit: the max allowed bytes to request at a time from Earth
        Engine. By default, it is 48MBs.
      try_auto_init: whether to attempt to automatically initialize Earth
        Engine on workers. Set to True when using distributed compute
        frameworks.
      init_kwargs: keyword arguments to pass to ee.Initialize() when
        auto-initializing on remote workers.

    Returns:
      An xarray.Dataset that streams in remote data from Earth Engine.
@@ -1022,6 +1070,8 @@
        primary_dim_property=primary_dim_property,
        mask_value=ee_mask_value,
        request_byte_limit=request_byte_limit,
        init_kwargs=init_kwargs,
        try_auto_init=try_auto_init,
    )

    store_entrypoint = backends_store.StoreBackendEntrypoint()
