Skip to content

Commit

Permalink
Revert "Deprecate gcs-config (#1024)"
Browse files Browse the repository at this point in the history
This reverts commit 9702a15.
  • Loading branch information
michaelbanfield committed Dec 15, 2020
1 parent 59ddbc4 commit ca8e327
Show file tree
Hide file tree
Showing 10 changed files with 642 additions and 12 deletions.
12 changes: 0 additions & 12 deletions tensorflow_io/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -744,18 +744,6 @@ cc_binary(
}),
)

cc_binary(
name = "python/ops/libtensorflow_io_plugins.so",
copts = tf_io_copts(),
linkshared = 1,
deps = select({
"//tensorflow_io/core:static_build_on": [],
"//conditions:default": [
"//tensorflow_io/core/plugins:plugins",
],
}),
)

cc_binary(
name = "python/ops/libtensorflow_io_golang.so",
copts = tf_io_copts(),
Expand Down
25 changes: 25 additions & 0 deletions tensorflow_io/gcs/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
licenses(["notice"]) # Apache 2.0

package(default_visibility = ["//visibility:public"])

load(
"//:tools/build/tensorflow_io.bzl",
"tf_io_copts",
)

cc_library(
name = "gcs_config_ops",
srcs = [
"kernels/gcs_config_op_kernels.cc",
"ops/gcs_config_ops.cc",
],
copts = tf_io_copts(),
linkstatic = True,
deps = [
"@curl",
"@jsoncpp_git//:jsoncpp",
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
],
alwayslink = 1,
)
3 changes: 3 additions & 0 deletions tensorflow_io/gcs/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## Cloud Storage (GCS) ##

The Google Cloud Storage ops allow the user to configure the GCS File System.
29 changes: 29 additions & 0 deletions tensorflow_io/gcs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for cloud ops."""


from tensorflow.python.util.all_util import remove_undocumented

# pylint: disable=line-too-long,wildcard-import,g-import-not-at-top
from tensorflow_io.gcs.python.ops.gcs_config_ops import *

_allowed_symbols = [
"configure_colab_session",
"configure_gcs",
"BlockCacheParams",
"ConfigureGcsHook",
]
remove_undocumented(__name__, _allowed_symbols)
206 changes: 206 additions & 0 deletions tensorflow_io/gcs/kernels/gcs_config_op_kernels.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <sstream>

#include "include/json/json.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/cloud/curl_http_request.h"
#include "tensorflow/core/platform/cloud/gcs_file_system.h"
#include "tensorflow/core/platform/cloud/oauth_client.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace {

// The default initial delay between retries with exponential backoff.
constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec

// The minimum time delta between now and the token expiration time
// for the token to be re-used.
constexpr int kExpirationTimeMarginSec = 60;

// The URL to retrieve the auth bearer token via OAuth with a refresh token.
constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token";

// The URL to retrieve the auth bearer token via OAuth with a private key.
constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token";

// The authentication token scope to request.
constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform";

Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) {
DCHECK(fs != nullptr);
*fs = nullptr;

FileSystem* filesystem = nullptr;
TF_RETURN_IF_ERROR(
ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem));
if (filesystem == nullptr) {
return errors::FailedPrecondition("The GCS file system is not registered.");
}

*fs = dynamic_cast<RetryingGcsFileSystem*>(filesystem);
if (*fs == nullptr) {
return errors::Internal(
"The filesystem registered under the 'gs://' scheme was not a "
"tensorflow::RetryingGcsFileSystem*.");
}
return Status::OK();
}

template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name,
T* output) {
const Tensor* argument_t;
TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
return errors::InvalidArgument(argument_name, " must be a scalar");
}
*output = argument_t->scalar<T>()();
return Status::OK();
}

// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem.
class GcsCredentialsOpKernel : public OpKernel {
public:
explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// Get a handle to the GCS file system.
RetryingGcsFileSystem* gcs = nullptr;
OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs));

tstring json_string;
OP_REQUIRES_OK(ctx,
ParseScalarArgument<tstring>(ctx, "json", &json_string));

Json::Value json;
Json::Reader reader;
std::stringstream json_stream(json_string);
OP_REQUIRES(ctx, reader.parse(json_stream, json),
errors::InvalidArgument("Could not parse json: ", json_string));

OP_REQUIRES(
ctx, json.isMember("refresh_token") || json.isMember("private_key"),
errors::InvalidArgument("JSON format incompatible; did not find fields "
"`refresh_token` or `private_key`."));

auto provider =
tensorflow::MakeUnique<ConstantAuthProvider>(json, ctx->env());

// Test getting a token
string dummy_token;
OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token));
OP_REQUIRES(ctx, !dummy_token.empty(),
errors::InvalidArgument(
"Could not retrieve a token with the given credentials."));

// Set the provider.
gcs->underlying()->SetAuthProvider(std::move(provider));
}

private:
class ConstantAuthProvider : public AuthProvider {
public:
ConstantAuthProvider(const Json::Value& json,
std::unique_ptr<OAuthClient> oauth_client, Env* env,
int64 initial_retry_delay_usec)
: json_(json),
oauth_client_(std::move(oauth_client)),
env_(env),
initial_retry_delay_usec_(initial_retry_delay_usec) {}

ConstantAuthProvider(const Json::Value& json, Env* env)
: ConstantAuthProvider(json, tensorflow::MakeUnique<OAuthClient>(), env,
kInitialRetryDelayUsec) {}

~ConstantAuthProvider() override {}

Status GetToken(string* token) override {
mutex_lock l(mu_);
const uint64 now_sec = env_->NowSeconds();

if (!current_token_.empty() &&
now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) {
*token = current_token_;
return Status::OK();
}
if (json_.isMember("refresh_token")) {
TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson(
json_, kOAuthV3Url, &current_token_, &expiration_timestamp_sec_));
} else if (json_.isMember("private_key")) {
TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson(
json_, kOAuthV4Url, kOAuthScope, &current_token_,
&expiration_timestamp_sec_));
} else {
return errors::FailedPrecondition(
"Unexpected content of the JSON credentials file.");
}

*token = current_token_;
return Status::OK();
}

private:
Json::Value json_;
std::unique_ptr<OAuthClient> oauth_client_;
Env* env_;

mutex mu_;
string current_token_ TF_GUARDED_BY(mu_);
uint64 expiration_timestamp_sec_ TF_GUARDED_BY(mu_) = 0;

// The initial delay for exponential backoffs when retrying failed calls.
const int64 initial_retry_delay_usec_;
TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider);
};
};

REGISTER_KERNEL_BUILDER(Name("IO>GcsConfigureCredentials").Device(DEVICE_CPU),
GcsCredentialsOpKernel);

class GcsBlockCacheOpKernel : public OpKernel {
public:
explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// Get a handle to the GCS file system.
RetryingGcsFileSystem* gcs = nullptr;
OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs));

size_t max_cache_size, block_size, max_staleness;
OP_REQUIRES_OK(ctx, ParseScalarArgument<size_t>(ctx, "max_cache_size",
&max_cache_size));
OP_REQUIRES_OK(ctx,
ParseScalarArgument<size_t>(ctx, "block_size", &block_size));
OP_REQUIRES_OK(
ctx, ParseScalarArgument<size_t>(ctx, "max_staleness", &max_staleness));

if (gcs->underlying()->block_size() == block_size &&
gcs->underlying()->max_bytes() == max_cache_size &&
gcs->underlying()->max_staleness() == max_staleness) {
LOG(INFO) << "Skipping resetting the GCS block cache.";
return;
}
gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size,
max_staleness);
}
};

REGISTER_KERNEL_BUILDER(Name("IO>GcsConfigureBlockCache").Device(DEVICE_CPU),
GcsBlockCacheOpKernel);

} // namespace
} // namespace tensorflow
66 changes: 66 additions & 0 deletions tensorflow_io/gcs/ops/gcs_config_ops.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::InferenceContext;

REGISTER_OP("IO>GcsConfigureCredentials")
.Input("json: string")
.SetShapeFn(shape_inference::NoOutputs)
.Doc(R"doc(
Configures the credentials used by the GCS client of the local TF runtime.
The json input can be of the format:
1. Refresh Token:
{
"client_id": "<redacted>",
"client_secret": "<redacted>",
"refresh_token: "<redacted>",
"type": "authorized_user",
}
2. Service Account:
{
"type": "service_account",
"project_id": "<redacted>",
"private_key_id": "<redacted>",
"private_key": "------BEGIN PRIVATE KEY-----\n<REDACTED>\n-----END PRIVATE KEY------\n",
"client_email": "<REDACTED>@<REDACTED>.iam.gserviceaccount.com",
"client_id": "<REDACTED>",
# Some additional fields elided
}
Note the credentials established through this method are shared across all
sessions run on this runtime.
Note be sure to feed the inputs to this op to ensure the credentials are not
stored in a constant op within the graph that might accidentally be checkpointed
or in other ways be persisted or exfiltrated.
)doc");

REGISTER_OP("IO>GcsConfigureBlockCache")
.Input("max_cache_size: uint64")
.Input("block_size: uint64")
.Input("max_staleness: uint64")
.SetShapeFn(shape_inference::NoOutputs)
.Doc(R"doc(
Re-configures the GCS block cache with the new configuration values.
If the values are the same as already configured values, this op is a no-op. If
they are different, the current contents of the block cache is dropped, and a
new block cache is created fresh.
)doc");

} // namespace tensorflow
16 changes: 16 additions & 0 deletions tensorflow_io/gcs/python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""This module contains Python API methods for GCS integration."""
16 changes: 16 additions & 0 deletions tensorflow_io/gcs/python/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""This module contains the Python API methods for GCS integration."""
Loading

0 comments on commit ca8e327

Please sign in to comment.