diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index 3563b0432..a7d381034 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -744,18 +744,6 @@ cc_binary(
     }),
 )
 
-cc_binary(
-    name = "python/ops/libtensorflow_io_plugins.so",
-    copts = tf_io_copts(),
-    linkshared = 1,
-    deps = select({
-        "//tensorflow_io/core:static_build_on": [],
-        "//conditions:default": [
-            "//tensorflow_io/core/plugins:plugins",
-        ],
-    }),
-)
-
 cc_binary(
     name = "python/ops/libtensorflow_io_golang.so",
     copts = tf_io_copts(),
diff --git a/tensorflow_io/gcs/BUILD b/tensorflow_io/gcs/BUILD
new file mode 100644
index 000000000..1c66d0860
--- /dev/null
+++ b/tensorflow_io/gcs/BUILD
@@ -0,0 +1,25 @@
+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
+load(
+    "//:tools/build/tensorflow_io.bzl",
+    "tf_io_copts",
+)
+
+cc_library(
+    name = "gcs_config_ops",
+    srcs = [
+        "kernels/gcs_config_op_kernels.cc",
+        "ops/gcs_config_ops.cc",
+    ],
+    copts = tf_io_copts(),
+    linkstatic = True,
+    deps = [
+        "@curl",
+        "@jsoncpp_git//:jsoncpp",
+        "@local_config_tf//:libtensorflow_framework",
+        "@local_config_tf//:tf_header_lib",
+    ],
+    alwayslink = 1,
+)
diff --git a/tensorflow_io/gcs/README.md b/tensorflow_io/gcs/README.md
new file mode 100644
index 000000000..99782a341
--- /dev/null
+++ b/tensorflow_io/gcs/README.md
@@ -0,0 +1,3 @@
+## Cloud Storage (GCS) ##
+
+The Google Cloud Storage ops allow the user to configure the GCS File System.
diff --git a/tensorflow_io/gcs/__init__.py b/tensorflow_io/gcs/__init__.py
new file mode 100644
index 000000000..39f6154b7
--- /dev/null
+++ b/tensorflow_io/gcs/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module for cloud ops."""
+
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+# pylint: disable=line-too-long,wildcard-import,g-import-not-at-top
+from tensorflow_io.gcs.python.ops.gcs_config_ops import *
+
+_allowed_symbols = [
+    "configure_colab_session",
+    "configure_gcs",
+    "BlockCacheParams",
+    "ConfigureGcsHook",
+]
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow_io/gcs/kernels/gcs_config_op_kernels.cc b/tensorflow_io/gcs/kernels/gcs_config_op_kernels.cc
new file mode 100644
index 000000000..3fd878a73
--- /dev/null
+++ b/tensorflow_io/gcs/kernels/gcs_config_op_kernels.cc
@@ -0,0 +1,206 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <sstream>
+
+#include "include/json/json.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/cloud/curl_http_request.h"
+#include "tensorflow/core/platform/cloud/gcs_file_system.h"
+#include "tensorflow/core/platform/cloud/oauth_client.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace {
+
+// The default initial delay between retries with exponential backoff.
+constexpr int kInitialRetryDelayUsec = 500000;  // 0.5 sec
+
+// The minimum time delta between now and the token expiration time
+// for the token to be re-used.
+constexpr int kExpirationTimeMarginSec = 60;
+
+// The URL to retrieve the auth bearer token via OAuth with a refresh token.
+constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token";
+
+// The URL to retrieve the auth bearer token via OAuth with a private key.
+constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token";
+
+// The authentication token scope to request.
+constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform";
+
+Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) {
+  DCHECK(fs != nullptr);
+  *fs = nullptr;
+
+  FileSystem* filesystem = nullptr;
+  TF_RETURN_IF_ERROR(
+      ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem));
+  if (filesystem == nullptr) {
+    return errors::FailedPrecondition("The GCS file system is not registered.");
+  }
+
+  *fs = dynamic_cast<RetryingGcsFileSystem*>(filesystem);
+  if (*fs == nullptr) {
+    return errors::Internal(
+        "The filesystem registered under the 'gs://' scheme was not a "
+        "tensorflow::RetryingGcsFileSystem*.");
+  }
+  return Status::OK();
+}
+
+template <typename T>
+Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name,
+                           T* output) {
+  const Tensor* argument_t;
+  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+  if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+    return errors::InvalidArgument(argument_name, " must be a scalar");
+  }
+  *output = argument_t->scalar<T>()();
+  return Status::OK();
+}
+
+// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem.
+class GcsCredentialsOpKernel : public OpKernel {
+ public:
+  explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  void Compute(OpKernelContext* ctx) override {
+    // Get a handle to the GCS file system.
+    RetryingGcsFileSystem* gcs = nullptr;
+    OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs));
+
+    tstring json_string;
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument<tstring>(ctx, "json", &json_string));
+
+    Json::Value json;
+    Json::Reader reader;
+    std::stringstream json_stream(json_string);
+    OP_REQUIRES(ctx, reader.parse(json_stream, json),
+                errors::InvalidArgument("Could not parse json: ", json_string));
+
+    OP_REQUIRES(
+        ctx, json.isMember("refresh_token") || json.isMember("private_key"),
+        errors::InvalidArgument("JSON format incompatible; did not find fields "
+                                "`refresh_token` or `private_key`."));
+
+    auto provider =
+        tensorflow::MakeUnique<ConstantAuthProvider>(json, ctx->env());
+
+    // Test getting a token
+    string dummy_token;
+    OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token));
+    OP_REQUIRES(ctx, !dummy_token.empty(),
+                errors::InvalidArgument(
+                    "Could not retrieve a token with the given credentials."));
+
+    // Set the provider.
+    gcs->underlying()->SetAuthProvider(std::move(provider));
+  }
+
+ private:
+  class ConstantAuthProvider : public AuthProvider {
+   public:
+    ConstantAuthProvider(const Json::Value& json,
+                         std::unique_ptr<OAuthClient> oauth_client, Env* env,
+                         int64 initial_retry_delay_usec)
+        : json_(json),
+          oauth_client_(std::move(oauth_client)),
+          env_(env),
+          initial_retry_delay_usec_(initial_retry_delay_usec) {}
+
+    ConstantAuthProvider(const Json::Value& json, Env* env)
+        : ConstantAuthProvider(json, tensorflow::MakeUnique<OAuthClient>(), env,
+                               kInitialRetryDelayUsec) {}
+
+    ~ConstantAuthProvider() override {}
+
+    Status GetToken(string* token) override {
+      mutex_lock l(mu_);
+      const uint64 now_sec = env_->NowSeconds();
+
+      if (!current_token_.empty() &&
+          now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) {
+        *token = current_token_;
+        return Status::OK();
+      }
+      if (json_.isMember("refresh_token")) {
+        TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson(
+            json_, kOAuthV3Url, &current_token_, &expiration_timestamp_sec_));
+      } else if (json_.isMember("private_key")) {
+        TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson(
+            json_, kOAuthV4Url, kOAuthScope, &current_token_,
+            &expiration_timestamp_sec_));
+      } else {
+        return errors::FailedPrecondition(
+            "Unexpected content of the JSON credentials file.");
+      }
+
+      *token = current_token_;
+      return Status::OK();
+    }
+
+   private:
+    Json::Value json_;
+    std::unique_ptr<OAuthClient> oauth_client_;
+    Env* env_;
+
+    mutex mu_;
+    string current_token_ TF_GUARDED_BY(mu_);
+    uint64 expiration_timestamp_sec_ TF_GUARDED_BY(mu_) = 0;
+
+    // The initial delay for exponential backoffs when retrying failed calls.
+    const int64 initial_retry_delay_usec_;
+    TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider);
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("IO>GcsConfigureCredentials").Device(DEVICE_CPU),
+                        GcsCredentialsOpKernel);
+
+class GcsBlockCacheOpKernel : public OpKernel {
+ public:
+  explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  void Compute(OpKernelContext* ctx) override {
+    // Get a handle to the GCS file system.
+    RetryingGcsFileSystem* gcs = nullptr;
+    OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs));
+
+    size_t max_cache_size, block_size, max_staleness;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<size_t>(ctx, "max_cache_size",
+                                                    &max_cache_size));
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument<size_t>(ctx, "block_size", &block_size));
+    OP_REQUIRES_OK(
+        ctx, ParseScalarArgument<size_t>(ctx, "max_staleness", &max_staleness));
+
+    if (gcs->underlying()->block_size() == block_size &&
+        gcs->underlying()->max_bytes() == max_cache_size &&
+        gcs->underlying()->max_staleness() == max_staleness) {
+      LOG(INFO) << "Skipping resetting the GCS block cache.";
+      return;
+    }
+    gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size,
+                                           max_staleness);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("IO>GcsConfigureBlockCache").Device(DEVICE_CPU),
+                        GcsBlockCacheOpKernel);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow_io/gcs/ops/gcs_config_ops.cc b/tensorflow_io/gcs/ops/gcs_config_ops.cc
new file mode 100644
index 000000000..140dbc3a3
--- /dev/null
+++ b/tensorflow_io/gcs/ops/gcs_config_ops.cc
@@ -0,0 +1,66 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+
+REGISTER_OP("IO>GcsConfigureCredentials")
+    .Input("json: string")
+    .SetShapeFn(shape_inference::NoOutputs)
+    .Doc(R"doc(
+Configures the credentials used by the GCS client of the local TF runtime.
+The json input can be of the format:
+1. Refresh Token:
+{
+  "client_id": "<client_id>",
+  "client_secret": "<client_secret>",
+  "refresh_token": "<refresh_token>",
+  "type": "authorized_user",
+}
+2. Service Account:
+{
+  "type": "service_account",
+  "project_id": "<project_id>",
+  "private_key_id": "<private_key_id>",
+  "private_key": "-----BEGIN PRIVATE KEY-----\n<private_key>\n-----END PRIVATE KEY-----\n",
+  "client_email": "<client_email>@<project_id>.iam.gserviceaccount.com",
+  "client_id": "<client_id>",
+  # Some additional fields elided
+}
+Note the credentials established through this method are shared across all
+sessions run on this runtime.
+Note be sure to feed the inputs to this op to ensure the credentials are not
+stored in a constant op within the graph that might accidentally be checkpointed
+or in other ways be persisted or exfiltrated.
+)doc");
+
+REGISTER_OP("IO>GcsConfigureBlockCache")
+    .Input("max_cache_size: uint64")
+    .Input("block_size: uint64")
+    .Input("max_staleness: uint64")
+    .SetShapeFn(shape_inference::NoOutputs)
+    .Doc(R"doc(
+Re-configures the GCS block cache with the new configuration values.
+If the values are the same as already configured values, this op is a no-op. If
+they are different, the current contents of the block cache are dropped, and a
+new block cache is created fresh.
+)doc");
+
+}  // namespace tensorflow
diff --git a/tensorflow_io/gcs/python/__init__.py b/tensorflow_io/gcs/python/__init__.py
new file mode 100644
index 000000000..f00d24fd2
--- /dev/null
+++ b/tensorflow_io/gcs/python/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""This module contains Python API methods for GCS integration."""
diff --git a/tensorflow_io/gcs/python/ops/__init__.py b/tensorflow_io/gcs/python/ops/__init__.py
new file mode 100644
index 000000000..568c0e67a
--- /dev/null
+++ b/tensorflow_io/gcs/python/ops/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""This module contains the Python API methods for GCS integration."""
diff --git a/tensorflow_io/gcs/python/ops/gcs_config_ops.py b/tensorflow_io/gcs/python/ops/gcs_config_ops.py
new file mode 100644
index 000000000..148602fe1
--- /dev/null
+++ b/tensorflow_io/gcs/python/ops/gcs_config_ops.py
@@ -0,0 +1,235 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GCS file system configuration for TensorFlow."""
+
+
+import json
+import os
+
+import tensorflow as tf
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.training import training
+from tensorflow_io.core.python.ops import core_ops
+
+# Some GCS operations may be pre-defined and available via tf.contrib in
+# earlier TF versions. Because these ops are pre-registered, they will not be
+# visible from the _gcs_config_ops library. In this case we use the tf.contrib
+# version instead.
+tf_v1 = tf.version.VERSION.startswith("1")
+
+if not tf_v1:
+    gcs_configure_credentials = core_ops.io_gcs_configure_credentials
+    gcs_configure_block_cache = core_ops.io_gcs_configure_block_cache
+
+
+class BlockCacheParams:  # pylint: disable=useless-object-inheritance
+    """BlockCacheParams is a struct used for configuring the GCS Block Cache."""
+
+    def __init__(self, block_size=None, max_bytes=None, max_staleness=None):
+        self._block_size = block_size or 128 * 1024 * 1024
+        self._max_bytes = max_bytes or 2 * self._block_size
+        self._max_staleness = max_staleness or 0
+
+    @property
+    def block_size(self):
+        return self._block_size
+
+    @property
+    def max_bytes(self):
+        return self._max_bytes
+
+    @property
+    def max_staleness(self):
+        return self._max_staleness
+
+
+class ConfigureGcsHook(training.SessionRunHook):
+    """ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator.
+
+    Warning: GCS `credentials` may be transmitted over the network unencrypted.
+    Please ensure that the network is trusted before using this function. For
+    users running code entirely within Google Cloud, your data is protected by
+    encryption in between data centers. For more information, please take a look
+    at https://cloud.google.com/security/encryption-in-transit/.
+
+    Example:
+
+    ```
+    sess = tf.Session()
+    refresh_token = raw_input("Refresh token: ")
+    client_secret = raw_input("Client secret: ")
+    client_id = ""
+    creds = {
+        "client_id": client_id,
+        "refresh_token": refresh_token,
+        "client_secret": client_secret,
+        "type": "authorized_user",
+    }
+    tf.contrib.cloud.configure_gcs(sess, credentials=creds)
+    ```
+
+    """
+
+    def _verify_dictionary(self, creds_dict):
+        if "refresh_token" in creds_dict or "private_key" in creds_dict:
+            return True
+        return False
+
+    def __init__(self, credentials=None, block_cache=None):
+        """Constructs a ConfigureGcsHook.
+
+        Args:
+          credentials: A json-formatted string.
+          block_cache: A `BlockCacheParams`.
+
+        Raises:
+          ValueError: If credentials is improperly formatted or block_cache is not a
+            BlockCacheParams.
+        """
+        if credentials is not None:
+            if isinstance(credentials, str):
+                try:
+                    data = json.loads(credentials)
+                except ValueError as e:
+                    raise ValueError(
+                        "credentials was not a well formed JSON string.", e
+                    )
+                if not self._verify_dictionary(data):
+                    raise ValueError(
+                        'credentials has neither a "refresh_token" nor a "private_key" '
+                        "field."
+                    )
+            elif isinstance(credentials, dict):
+                if not self._verify_dictionary(credentials):
+                    raise ValueError(
+                        'credentials has neither a "refresh_token" nor a '
+                        '"private_key" field.'
+                    )
+                credentials = json.dumps(credentials)
+            else:
+                raise ValueError("credentials is of an unknown type")
+
+        self._credentials = credentials
+
+        if block_cache and not isinstance(block_cache, BlockCacheParams):
+            raise ValueError("block_cache must be an instance of BlockCacheParams.")
+        self._block_cache = block_cache
+
+    def begin(self):
+        """Called once before using the session.
+
+        When called, the default graph is the one that will be launched in the
+        session. The hook can modify the graph by adding new operations to it.
+        After the `begin()` call the graph will be finalized and the other callbacks
+        can not modify the graph anymore. Second call of `begin()` on the same
+        graph, should not change the graph.
+        """
+        if self._credentials:
+            self._credentials_placeholder = array_ops.placeholder(dtypes.string)
+            self._credentials_op = gcs_configure_credentials(
+                self._credentials_placeholder
+            )
+        else:
+            self._credentials_op = None
+
+        if self._block_cache:
+            self._block_cache_op = gcs_configure_block_cache(
+                max_cache_size=self._block_cache.max_bytes,
+                block_size=self._block_cache.block_size,
+                max_staleness=self._block_cache.max_staleness,
+            )
+        else:
+            self._block_cache_op = None
+
+    def after_create_session(self, session, coord):
+        """Called when new TensorFlow session is created.
+
+        This is called to signal the hooks that a new session has been created. This
+        has two essential differences with the situation in which `begin` is called:
+
+        * When this is called, the graph is finalized and ops can no longer be added
+          to the graph.
+        * This method will also be called as a result of recovering a wrapped
+          session, not only at the beginning of the overall session.
+
+        Args:
+          session: A TensorFlow Session that has been created.
+          coord: A Coordinator object which keeps track of all threads.
+        """
+        del coord
+        if self._credentials_op:
+            session.run(
+                self._credentials_op,
+                feed_dict={self._credentials_placeholder: self._credentials},
+            )
+        if self._block_cache_op:
+            session.run(self._block_cache_op)
+
+
+def _configure_gcs_tfv2(credentials=None, block_cache=None, device=None):
+    """Configures the GCS file system for a given session.
+
+    Warning: GCS `credentials` may be transmitted over the network unencrypted.
+    Please ensure that the network is trusted before using this function. For
+    users running code entirely within Google Cloud, your data is protected by
+    encryption in between data centers. For more information, please take a look
+    at https://cloud.google.com/security/encryption-in-transit/.
+
+    Args:
+      credentials: [Optional.] A JSON string.
+      block_cache: [Optional.] A BlockCacheParams to configure the block cache.
+      device: [Optional.] The device on which to place the configure ops.
+    """
+
+    def configure(credentials, block_cache):
+        """Helper function to actually configure GCS."""
+        if credentials:
+            if isinstance(credentials, dict):
+                credentials = json.dumps(credentials)
+            gcs_configure_credentials(credentials)
+
+        if block_cache:
+            gcs_configure_block_cache(
+                max_cache_size=block_cache.max_bytes,
+                block_size=block_cache.block_size,
+                max_staleness=block_cache.max_staleness,
+            )
+
+    if device:
+        with ops.device(device):
+            return configure(credentials, block_cache)
+    return configure(credentials, block_cache)
+
+
+def _configure_colab_session_tfv2():
+    """ConfigureColabSession configures the GCS file system in Colab.
+
+    Args:
+    """
+    # Read from the application default credentials (adc).
+    adc_filename = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", "/content/adc.json")
+    with open(adc_filename) as f:
+        data = json.load(f)
+    configure_gcs(credentials=data)
+
+
+if tf_v1:
+    configure_gcs = tf.contrib.cloud.configure_gcs
+    configure_colab_session = tf.contrib.cloud.configure_colab_session
+else:
+    configure_gcs = _configure_gcs_tfv2
+    configure_colab_session = _configure_colab_session_tfv2
diff --git a/tests/test_gcs_config_ops.py b/tests/test_gcs_config_ops.py
new file mode 100644
index 000000000..7d0140900
--- /dev/null
+++ b/tests/test_gcs_config_ops.py
@@ -0,0 +1,46 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the gcs_config_ops."""
+
+
+import sys
+import pytest
+
+import tensorflow as tf
+
+from tensorflow.python.platform import test
+from tensorflow_io import gcs
+
+tf_v1 = tf.version.VERSION.startswith("1")
+
+
+class GcsConfigOpsTest(test.TestCase):
+    """GCS Config OPS test"""
+
+    @pytest.mark.skipif(sys.platform == "darwin", reason=None)
+    def test_set_block_cache(self):
+        """test_set_block_cache"""
+        cfg = gcs.BlockCacheParams(max_bytes=1024 * 1024 * 1024)
+        if tf_v1:
+            with tf.Session() as session:
+                gcs.configure_gcs(
+                    session, credentials=None, block_cache=cfg, device=None
+                )
+        else:
+            gcs.configure_gcs(block_cache=cfg)
+
+
+if __name__ == "__main__":
+    test.main()
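
For reviewers who want to try the change out, the following is a minimal usage sketch of the Python surface this PR adds in the TF2/eager path (`BlockCacheParams` and `configure_gcs` from `tensorflow_io.gcs`), based on the docstrings and the test above. The credentials file name and the cache sizes are illustrative, not part of the patch.

```python
# Minimal sketch (not part of the patch): TF2 path of the new tensorflow_io.gcs API.
# Assumes tensorflow-io is built with these GCS ops; "service_account.json" is an
# illustrative path.
import json

from tensorflow_io import gcs

# 64 MiB blocks, a 1 GiB cache, and 5 minutes of staleness tolerance.
cache = gcs.BlockCacheParams(
    block_size=64 * 1024 * 1024,
    max_bytes=1024 * 1024 * 1024,
    max_staleness=300,
)

# Credentials may be a dict or a JSON string; it must contain either a
# `refresh_token` or a `private_key` field (see IO>GcsConfigureCredentials).
with open("service_account.json") as f:
    creds = json.load(f)

gcs.configure_gcs(credentials=creds, block_cache=cache)
```

Without a `credentials` argument, `configure_gcs(block_cache=...)` only resets the block cache, which is exactly what the test above exercises.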
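A corresponding sketch for the graph/session-based path through `ConfigureGcsHook`: the hook is aimed at Estimator/TPUEstimator training loops per its docstring, and the Estimator wiring at the end is only indicated in a comment since no Estimator code is part of this PR. The credential values below are placeholders.

```python
# Sketch (not part of the patch): constructing the ConfigureGcsHook added here.
# The credential values are placeholders, not working secrets.
from tensorflow_io import gcs

hook = gcs.ConfigureGcsHook(
    credentials={
        "client_id": "...",
        "client_secret": "...",
        "refresh_token": "...",
        "type": "authorized_user",
    },
    block_cache=gcs.BlockCacheParams(max_bytes=1024 * 1024 * 1024),
)

# A typical consumer would be any SessionRunHook-aware training loop, e.g.:
# estimator.train(input_fn=train_input_fn, hooks=[hook])
```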