-
Notifications
You must be signed in to change notification settings - Fork 284
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revert "Deprecate gcs-config (#1024)"
This reverts commit 9702a15.
- Loading branch information
1 parent
59ddbc4
commit ca8e327
Showing
10 changed files
with
642 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
licenses(["notice"]) # Apache 2.0 | ||
|
||
package(default_visibility = ["//visibility:public"]) | ||
|
||
load( | ||
"//:tools/build/tensorflow_io.bzl", | ||
"tf_io_copts", | ||
) | ||
|
||
cc_library( | ||
name = "gcs_config_ops", | ||
srcs = [ | ||
"kernels/gcs_config_op_kernels.cc", | ||
"ops/gcs_config_ops.cc", | ||
], | ||
copts = tf_io_copts(), | ||
linkstatic = True, | ||
deps = [ | ||
"@curl", | ||
"@jsoncpp_git//:jsoncpp", | ||
"@local_config_tf//:libtensorflow_framework", | ||
"@local_config_tf//:tf_header_lib", | ||
], | ||
alwayslink = 1, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
## Cloud Storage (GCS) ## | ||
|
||
The Google Cloud Storage ops allow the user to configure the GCS File System. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Module for cloud ops.""" | ||
|
||
|
||
from tensorflow.python.util.all_util import remove_undocumented | ||
|
||
# pylint: disable=line-too-long,wildcard-import,g-import-not-at-top | ||
from tensorflow_io.gcs.python.ops.gcs_config_ops import * | ||
|
||
_allowed_symbols = [ | ||
"configure_colab_session", | ||
"configure_gcs", | ||
"BlockCacheParams", | ||
"ConfigureGcsHook", | ||
] | ||
remove_undocumented(__name__, _allowed_symbols) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include <sstream> | ||
|
||
#include "include/json/json.h" | ||
#include "tensorflow/core/framework/op_kernel.h" | ||
#include "tensorflow/core/framework/tensor_shape.h" | ||
#include "tensorflow/core/platform/cloud/curl_http_request.h" | ||
#include "tensorflow/core/platform/cloud/gcs_file_system.h" | ||
#include "tensorflow/core/platform/cloud/oauth_client.h" | ||
#include "tensorflow/core/util/ptr_util.h" | ||
|
||
namespace tensorflow { | ||
namespace { | ||
|
||
// The default initial delay between retries with exponential backoff. | ||
constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec | ||
|
||
// The minimum time delta between now and the token expiration time | ||
// for the token to be re-used. | ||
constexpr int kExpirationTimeMarginSec = 60; | ||
|
||
// The URL to retrieve the auth bearer token via OAuth with a refresh token. | ||
constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token"; | ||
|
||
// The URL to retrieve the auth bearer token via OAuth with a private key. | ||
constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token"; | ||
|
||
// The authentication token scope to request. | ||
constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform"; | ||
|
||
Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) { | ||
DCHECK(fs != nullptr); | ||
*fs = nullptr; | ||
|
||
FileSystem* filesystem = nullptr; | ||
TF_RETURN_IF_ERROR( | ||
ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem)); | ||
if (filesystem == nullptr) { | ||
return errors::FailedPrecondition("The GCS file system is not registered."); | ||
} | ||
|
||
*fs = dynamic_cast<RetryingGcsFileSystem*>(filesystem); | ||
if (*fs == nullptr) { | ||
return errors::Internal( | ||
"The filesystem registered under the 'gs://' scheme was not a " | ||
"tensorflow::RetryingGcsFileSystem*."); | ||
} | ||
return Status::OK(); | ||
} | ||
|
||
template <typename T> | ||
Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name, | ||
T* output) { | ||
const Tensor* argument_t; | ||
TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); | ||
if (!TensorShapeUtils::IsScalar(argument_t->shape())) { | ||
return errors::InvalidArgument(argument_name, " must be a scalar"); | ||
} | ||
*output = argument_t->scalar<T>()(); | ||
return Status::OK(); | ||
} | ||
|
||
// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem. | ||
class GcsCredentialsOpKernel : public OpKernel { | ||
public: | ||
explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} | ||
void Compute(OpKernelContext* ctx) override { | ||
// Get a handle to the GCS file system. | ||
RetryingGcsFileSystem* gcs = nullptr; | ||
OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); | ||
|
||
tstring json_string; | ||
OP_REQUIRES_OK(ctx, | ||
ParseScalarArgument<tstring>(ctx, "json", &json_string)); | ||
|
||
Json::Value json; | ||
Json::Reader reader; | ||
std::stringstream json_stream(json_string); | ||
OP_REQUIRES(ctx, reader.parse(json_stream, json), | ||
errors::InvalidArgument("Could not parse json: ", json_string)); | ||
|
||
OP_REQUIRES( | ||
ctx, json.isMember("refresh_token") || json.isMember("private_key"), | ||
errors::InvalidArgument("JSON format incompatible; did not find fields " | ||
"`refresh_token` or `private_key`.")); | ||
|
||
auto provider = | ||
tensorflow::MakeUnique<ConstantAuthProvider>(json, ctx->env()); | ||
|
||
// Test getting a token | ||
string dummy_token; | ||
OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token)); | ||
OP_REQUIRES(ctx, !dummy_token.empty(), | ||
errors::InvalidArgument( | ||
"Could not retrieve a token with the given credentials.")); | ||
|
||
// Set the provider. | ||
gcs->underlying()->SetAuthProvider(std::move(provider)); | ||
} | ||
|
||
private: | ||
class ConstantAuthProvider : public AuthProvider { | ||
public: | ||
ConstantAuthProvider(const Json::Value& json, | ||
std::unique_ptr<OAuthClient> oauth_client, Env* env, | ||
int64 initial_retry_delay_usec) | ||
: json_(json), | ||
oauth_client_(std::move(oauth_client)), | ||
env_(env), | ||
initial_retry_delay_usec_(initial_retry_delay_usec) {} | ||
|
||
ConstantAuthProvider(const Json::Value& json, Env* env) | ||
: ConstantAuthProvider(json, tensorflow::MakeUnique<OAuthClient>(), env, | ||
kInitialRetryDelayUsec) {} | ||
|
||
~ConstantAuthProvider() override {} | ||
|
||
Status GetToken(string* token) override { | ||
mutex_lock l(mu_); | ||
const uint64 now_sec = env_->NowSeconds(); | ||
|
||
if (!current_token_.empty() && | ||
now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) { | ||
*token = current_token_; | ||
return Status::OK(); | ||
} | ||
if (json_.isMember("refresh_token")) { | ||
TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson( | ||
json_, kOAuthV3Url, &current_token_, &expiration_timestamp_sec_)); | ||
} else if (json_.isMember("private_key")) { | ||
TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson( | ||
json_, kOAuthV4Url, kOAuthScope, &current_token_, | ||
&expiration_timestamp_sec_)); | ||
} else { | ||
return errors::FailedPrecondition( | ||
"Unexpected content of the JSON credentials file."); | ||
} | ||
|
||
*token = current_token_; | ||
return Status::OK(); | ||
} | ||
|
||
private: | ||
Json::Value json_; | ||
std::unique_ptr<OAuthClient> oauth_client_; | ||
Env* env_; | ||
|
||
mutex mu_; | ||
string current_token_ TF_GUARDED_BY(mu_); | ||
uint64 expiration_timestamp_sec_ TF_GUARDED_BY(mu_) = 0; | ||
|
||
// The initial delay for exponential backoffs when retrying failed calls. | ||
const int64 initial_retry_delay_usec_; | ||
TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider); | ||
}; | ||
}; | ||
|
||
REGISTER_KERNEL_BUILDER(Name("IO>GcsConfigureCredentials").Device(DEVICE_CPU), | ||
GcsCredentialsOpKernel); | ||
|
||
class GcsBlockCacheOpKernel : public OpKernel { | ||
public: | ||
explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} | ||
void Compute(OpKernelContext* ctx) override { | ||
// Get a handle to the GCS file system. | ||
RetryingGcsFileSystem* gcs = nullptr; | ||
OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); | ||
|
||
size_t max_cache_size, block_size, max_staleness; | ||
OP_REQUIRES_OK(ctx, ParseScalarArgument<size_t>(ctx, "max_cache_size", | ||
&max_cache_size)); | ||
OP_REQUIRES_OK(ctx, | ||
ParseScalarArgument<size_t>(ctx, "block_size", &block_size)); | ||
OP_REQUIRES_OK( | ||
ctx, ParseScalarArgument<size_t>(ctx, "max_staleness", &max_staleness)); | ||
|
||
if (gcs->underlying()->block_size() == block_size && | ||
gcs->underlying()->max_bytes() == max_cache_size && | ||
gcs->underlying()->max_staleness() == max_staleness) { | ||
LOG(INFO) << "Skipping resetting the GCS block cache."; | ||
return; | ||
} | ||
gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size, | ||
max_staleness); | ||
} | ||
}; | ||
|
||
REGISTER_KERNEL_BUILDER(Name("IO>GcsConfigureBlockCache").Device(DEVICE_CPU), | ||
GcsBlockCacheOpKernel); | ||
|
||
} // namespace | ||
} // namespace tensorflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "tensorflow/core/framework/common_shape_fns.h" | ||
#include "tensorflow/core/framework/op.h" | ||
#include "tensorflow/core/framework/shape_inference.h" | ||
|
||
namespace tensorflow { | ||
|
||
using shape_inference::InferenceContext; | ||
|
||
REGISTER_OP("IO>GcsConfigureCredentials") | ||
.Input("json: string") | ||
.SetShapeFn(shape_inference::NoOutputs) | ||
.Doc(R"doc( | ||
Configures the credentials used by the GCS client of the local TF runtime. | ||
The json input can be of the format: | ||
1. Refresh Token: | ||
{ | ||
"client_id": "<redacted>", | ||
"client_secret": "<redacted>", | ||
"refresh_token": "<redacted>", | ||
"type": "authorized_user", | ||
} | ||
2. Service Account: | ||
{ | ||
"type": "service_account", | ||
"project_id": "<redacted>", | ||
"private_key_id": "<redacted>", | ||
"private_key": "------BEGIN PRIVATE KEY-----\n<REDACTED>\n-----END PRIVATE KEY------\n", | ||
"client_email": "<REDACTED>@<REDACTED>.iam.gserviceaccount.com", | ||
"client_id": "<REDACTED>", | ||
# Some additional fields elided | ||
} | ||
Note the credentials established through this method are shared across all | ||
sessions run on this runtime. | ||
Note be sure to feed the inputs to this op to ensure the credentials are not | ||
stored in a constant op within the graph that might accidentally be checkpointed | ||
or in other ways be persisted or exfiltrated. | ||
)doc"); | ||
|
||
REGISTER_OP("IO>GcsConfigureBlockCache") | ||
.Input("max_cache_size: uint64") | ||
.Input("block_size: uint64") | ||
.Input("max_staleness: uint64") | ||
.SetShapeFn(shape_inference::NoOutputs) | ||
.Doc(R"doc( | ||
Re-configures the GCS block cache with the new configuration values. | ||
If the values are the same as already configured values, this op is a no-op. If | ||
they are different, the current contents of the block cache are dropped, and a | ||
new block cache is created fresh. | ||
)doc"); | ||
|
||
} // namespace tensorflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""This module contains Python API methods for GCS integration.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""This module contains the Python API methods for GCS integration.""" |
Oops, something went wrong.