Skip to content

Commit

Permalink
Introduce a wrapper C++ API for dtensor device to be used by Tensorfl…
Browse files Browse the repository at this point in the history
…ow Federated.

PiperOrigin-RevId: 496825787
  • Loading branch information
ishark authored and tensorflow-copybara committed Dec 21, 2022
1 parent 9f9c673 commit 3ade095
Show file tree
Hide file tree
Showing 4 changed files with 593 additions and 0 deletions.
45 changes: 45 additions & 0 deletions tensorflow_federated/cc/core/impl/executors/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,51 @@ cc_test(
],
)

# C++ wrapper API around the DTensor device: device registration, Tensor <->
# DTensor conversion, CopyToMesh and Relayout (see dtensor_api.h).
cc_library(
name = "dtensor_api",
srcs = ["dtensor_api.cc"],
hdrs = ["dtensor_api.h"],
deps = [
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/c:c_api_experimental",
"@org_tensorflow//tensorflow/c:tf_datatype",
"@org_tensorflow//tensorflow/c:tf_status_headers",
"@org_tensorflow//tensorflow/c/eager:c_api",
"@org_tensorflow//tensorflow/c/eager:tfe_op_internal",
"@org_tensorflow//tensorflow/core:core_cpu_base",
"@org_tensorflow//tensorflow/core/common_runtime:core",
"@org_tensorflow//tensorflow/dtensor/cc:dtensor_device_cc",
"@org_tensorflow//tensorflow/dtensor/cc:dtensor_device_util",
"@org_tensorflow//tensorflow/dtensor/cc:mesh_type",
"@org_tensorflow//tensorflow/dtensor/cc:tensor_layout",
],
)

# Unit tests for the ":dtensor_api" wrapper library.
cc_test(
name = "dtensor_api_test",
srcs = ["dtensor_api_test.cc"],
deps = [
":dtensor_api",
"//tensorflow_federated/cc/common_libs:oss_test_main",
"@com_google_absl//absl/log:check",
"@org_tensorflow//tensorflow/c:tf_status_headers",
"@org_tensorflow//tensorflow/c:tf_tensor_internal",
"@org_tensorflow//tensorflow/c/eager:c_api",
"@org_tensorflow//tensorflow/c/eager:c_api_internal",
"@org_tensorflow//tensorflow/c/eager:immediate_execution_tensor_handle",
"@org_tensorflow//tensorflow/c/eager:tfe_context_internal",
"@org_tensorflow//tensorflow/c/eager:tfe_tensorhandle_internal",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core/platform:status",
"@org_tensorflow//tensorflow/dtensor/cc:dstatus",
"@org_tensorflow//tensorflow/dtensor/cc:mesh_type",
"@org_tensorflow//tensorflow/dtensor/cc:tensor_layout",
"@org_tensorflow//tensorflow/dtensor/proto:layout_proto_cc",
"@org_tensorflow//tensorflow/tsl/platform:status",
"@org_tensorflow//tensorflow/tsl/platform:statusor",
],
)

cc_library(
name = "dtensor_executor",
srcs = ["dtensor_executor.cc"],
Expand Down
191 changes: 191 additions & 0 deletions tensorflow_federated/cc/core/impl/executors/dtensor_api.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/* Copyright 2022, The TensorFlow Federated Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License
==============================================================================*/
#include "tensorflow_federated/cc/core/impl/executors/dtensor_api.h"

#include <cstring>
#include <memory>
#include <optional>
#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/dtensor/cc/dtensor_device.h"
#include "tensorflow/dtensor/cc/dtensor_device_util.h"
#include "tensorflow/dtensor/cc/mesh_type.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"

extern "C" {

// Allocates a DTensor custom device named `dtensor_device_name`, registers it
// with `context`, and then adds `mesh` to it (synchronous, non-host mesh, no
// in-flight node limit).
//
// Returns the opaque device-info pointer produced by AllocateDTensorDevice.
// If registering the device or adding the mesh fails, returns nullptr and
// leaves the failure in `status`.
void* TFE_DTENSOR_RegisterDTensorDevice(TFE_Context* context,
                                        tensorflow::TF_Mesh* mesh,
                                        const char* dtensor_device_name,
                                        TF_Status* status) {
  TFE_CustomDevice custom_device;
  void* info;
  tensorflow::dtensor::AllocateDTensorDevice(
      /*device_name=*/dtensor_device_name, &custom_device, &info);

  // AddMesh takes the mesh in its serialized string form.
  std::string serialized_mesh = tensorflow::unwrap(mesh)->ToString();

  TFE_RegisterCustomDevice(context, custom_device, dtensor_device_name, info,
                           status);
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }

  tensorflow::dtensor::AddMesh(serialized_mesh, info, /*is_async=*/false,
                               /*is_host_mesh=*/false,
                               /*in_flight_nodes_limit=*/0, status);
  return TF_GetCode(status) == TF_OK ? info : nullptr;
}

// Reports whether the device name of `tensor_handle` is exactly `device_name`.
//
// Returns false if the device name cannot be queried; in that case `status`
// carries the error, so callers must check `status` to tell "different
// device" apart from "lookup failed". `context` is not used.
bool TFE_DTENSOR_IsTensorHandleOnDevice(TFE_Context* context,
                                        TFE_TensorHandle* tensor_handle,
                                        const char* device_name,
                                        TF_Status* status) {
  const char* const handle_device_name =
      TFE_TensorHandleDeviceName(tensor_handle, status);
  // Short-circuit keeps strcmp away from the result of a failed lookup.
  return TF_GetCode(status) == TF_OK &&
         strcmp(handle_device_name, device_name) == 0;
}

// Converts the regular tensor `handle` into a DTensor with `layout` on the
// DTensor device `device_name`.
//
// A fully replicated `layout` needs a single CopyToMesh. Any other layout is
// produced in two steps: CopyToMesh with a replicated layout of the same mesh
// and rank, followed by Relayout to `layout`.
//
// Returns the resulting DTensor handle, or nullptr with `status` set on
// failure.
TFE_TensorHandle* TFE_DTENSOR_TensorToDTensor(
    TFE_Context* context, TFE_TensorHandle* handle,
    const tensorflow::TF_Layout* layout, const char* device_name,
    TF_Status* status) {
  const tensorflow::dtensor::Layout* target_layout = tensorflow::unwrap(layout);

  if (target_layout->IsFullyReplicated()) {
    TFE_TensorHandle* copied =
        TFE_DTENSOR_CopyToMesh(context, handle, layout, device_name, status);
    return TF_GetCode(status) == TF_OK ? copied : nullptr;
  }

  // Step 1: place the tensor on the mesh, replicated.
  auto replicated_layout = tensorflow::dtensor::Layout::ReplicatedOnMesh(
      target_layout->mesh(), target_layout->rank());
  TFE_TensorHandle* on_mesh = TFE_DTENSOR_CopyToMesh(
      context, handle, tensorflow::wrap(&replicated_layout), device_name,
      status);
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }

  // Step 2: shard the replicated DTensor into the requested layout.
  TFE_TensorHandle* relaid =
      TFE_DTENSOR_Relayout(context, on_mesh, layout, device_name, status);

  // The replicated intermediate is no longer needed, whether or not the
  // relayout succeeded.
  TFE_DeleteTensorHandle(on_mesh);
  return relaid;
}

// Converts the DTensor `dtensor_handle` into a regular tensor handle holding
// the global (unsharded) value.
//
// If the DTensor is already fully replicated, the tensor at index 0 is
// returned (as a new handle sharing the underlying tensor). Otherwise the
// DTensor is first relaid out to a replicated layout on the same mesh via
// TFE_DTENSOR_Relayout on `device_name`, and the tensor at index 0 of that
// result is returned. `context` and `device_name` are only used on the
// relayout path.
//
// Precondition: `dtensor_handle` must be a DTensor handle; its device pointer
// is interpreted as a tensorflow::dtensor::TensorWithLayout without further
// checks.
//
// Returns nullptr with `status` set if any step fails. The caller owns the
// returned handle.
TFE_TensorHandle* TFE_DTENSOR_DTensorToTensor(TFE_Context* context,
                                              TFE_TensorHandle* dtensor_handle,
                                              const char* device_name,
                                              TF_Status* status) {
  auto* t = static_cast<tensorflow::dtensor::TensorWithLayout*>(
      TFE_TensorHandleDevicePointer(dtensor_handle, status));
  if (TF_GetCode(status) != TF_OK) return nullptr;

  if (t->layout().IsFullyReplicated()) {
    // Every replica holds the full value; share the first one.
    return TFE_TensorHandleCopySharingTensor(t->get_tensor(0), status);
  }

  auto replicated_layout = tensorflow::dtensor::Layout::ReplicatedOnMesh(
      t->layout().mesh(), t->layout().rank());

  // Own the intermediate replicated DTensor so it is released on every exit
  // path, including the error returns below.
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      replicated(TFE_DTENSOR_Relayout(context, dtensor_handle,
                                      tensorflow::wrap(&replicated_layout),
                                      device_name, status),
                 TFE_DeleteTensorHandle);
  // Relayout returns nullptr on failure; stop here so its error stays in
  // `status` and the null handle is never passed on.
  if (TF_GetCode(status) != TF_OK) return nullptr;

  auto* t_replicated = static_cast<tensorflow::dtensor::TensorWithLayout*>(
      TFE_TensorHandleDevicePointer(replicated.get(), status));
  if (TF_GetCode(status) != TF_OK) return nullptr;

  return TFE_TensorHandleCopySharingTensor(t_replicated->get_tensor(0),
                                           status);
}

// Runs the "CopyToMesh" op on device `device_name` with `tensor_handle` as
// its only input and the string form of `layout` as the "layout" attribute.
//
// Returns the op's single output handle, or nullptr with `status` set if
// building or executing the op fails.
TFE_TensorHandle* TFE_DTENSOR_CopyToMesh(TFE_Context* context,
                                         TFE_TensorHandle* tensor_handle,
                                         const tensorflow::TF_Layout* layout,
                                         const char* device_name,
                                         TF_Status* status) {
  const auto failed = [status] { return TF_GetCode(status) != TF_OK; };

  // The op is released automatically on every return path.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> copy_op(
      TFE_NewOp(context, "CopyToMesh", status), TFE_DeleteOp);
  if (failed()) return nullptr;

  TFE_OpSetDevice(copy_op.get(), device_name, status);
  if (failed()) return nullptr;

  std::string layout_str = tensorflow::unwrap(layout)->ToString();
  TFE_OpSetAttrString(copy_op.get(), "layout", layout_str.data(),
                      layout_str.length());

  TFE_OpAddInput(copy_op.get(), tensor_handle, status);
  if (failed()) return nullptr;

  int num_outputs = 1;
  TFE_TensorHandle* output;
  TFE_Execute(copy_op.get(), &output, &num_outputs, status);
  if (failed()) return nullptr;

  return output;
}

// Runs the "Relayout" op on device `device_name` to change the layout of the
// DTensor `handle` to `layout`.
//
// `handle` must already live on `device_name`; otherwise `status` is set to
// TF_INVALID_ARGUMENT and nullptr is returned. nullptr is also returned, with
// `status` set, if building or executing the op fails.
TFE_TensorHandle* TFE_DTENSOR_Relayout(TFE_Context* context,
                                       TFE_TensorHandle* handle,
                                       const tensorflow::TF_Layout* layout,
                                       const char* device_name,
                                       TF_Status* status) {
  const bool on_dtensor_device =
      TFE_DTENSOR_IsTensorHandleOnDevice(context, handle, device_name, status);
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }
  if (!on_dtensor_device) {
    const std::string message = absl::StrCat(
        "Input to Relayout should be a DTensor on device ", device_name);
    TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
    return nullptr;
  }

  // The op is released automatically on every return path.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> relayout_op(
      TFE_NewOp(context, "Relayout", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }

  TFE_OpSetDevice(relayout_op.get(), device_name, status);
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }

  std::string layout_str = tensorflow::unwrap(layout)->ToString();
  TFE_OpSetAttrString(relayout_op.get(), "layout", layout_str.data(),
                      layout_str.length());

  TFE_OpAddInput(relayout_op.get(), handle, status);
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }

  int num_outputs = 1;
  TFE_TensorHandle* output;
  TFE_Execute(relayout_op.get(), &output, &num_outputs, status);
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }
  return output;
}
}
76 changes: 76 additions & 0 deletions tensorflow_federated/cc/core/impl/executors/dtensor_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/* Copyright 2022, The TensorFlow Federated Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_DTENSOR_API_H_
#define THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_DTENSOR_API_H_

#include <string>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/dtensor/cc/mesh_type.h"

extern "C" {

// Registers a DTensor device named `dtensor_device_name` with `context` and
// adds the provided `mesh` to it.
// Returns an opaque DeviceInfo pointer which can be used to add mesh, or
// nullptr with `status` set if registering the device or adding the mesh
// fails.
void* TFE_DTENSOR_RegisterDTensorDevice(TFE_Context* context,
tensorflow::TF_Mesh* mesh,
const char* dtensor_device_name,
TF_Status* status);

// Returns true if the device name of the given tensor_handle is exactly the
// provided device name, i.e. the handle points to a DTensor on that device.
// Returns false with `status` set if the handle's device name cannot be
// queried; check `status` to distinguish that from a name mismatch.
bool TFE_DTENSOR_IsTensorHandleOnDevice(TFE_Context* context,
TFE_TensorHandle* tensor_handle,
const char* device_name,
TF_Status* status);

// Copies a Tensor to a DTensor by sharding or replicating the input tensor
// according to the specified layout.
// For a fully replicated layout this is a single CopyToMesh; for any other
// layout the tensor is first copied to the mesh replicated and then relaid
// out to `layout`.
// Returns the DTensor handle, or nullptr with `status` set on failure.
TFE_TensorHandle* TFE_DTENSOR_TensorToDTensor(
TFE_Context* context, TFE_TensorHandle* tensor_handle,
const tensorflow::TF_Layout* layout, const char* device_name,
TF_Status* status);

// Copies the input DTensor to a Tensor by removing the sharding, and
// returns a handle to the global tensor value.
// The input handle is expected to be a DTensor handle on `device_name`.
// Callers should check `status` after the call; a non-OK status means the
// conversion failed.
TFE_TensorHandle* TFE_DTENSOR_DTensorToTensor(TFE_Context* context,
TFE_TensorHandle* dtensor_handle,
const char* device_name,
TF_Status* status);

// Copies a Tensor onto the mesh with a replicated layout and returns the
// resulting DTensor handle.
// CopyToMesh only supports replicated layouts.
// The input handle to CopyToMesh is expected to be a regular tensor.
// Returns nullptr with `status` set if the op cannot be built or executed.
TFE_TensorHandle* TFE_DTENSOR_CopyToMesh(TFE_Context* context,
TFE_TensorHandle* tensor_handle,
const tensorflow::TF_Layout* layout,
const char* device_name,
TF_Status* status);

// Changes the layout of the input DTensor to the provided layout and returns
// the resulting DTensor handle.
// Note that the input handle is expected to be a DTensor handle on
// `device_name`; passing a regular tensor to Relayout will result in an
// invalid argument status and a nullptr return value.
// Returns nullptr with `status` set if the op cannot be built or executed.
// TODO(b/256948367): Relayout does not support complex dtypes and some dtypes
// on GPU. Add documentation on supported types and fix the support for dtypes.
TFE_TensorHandle* TFE_DTENSOR_Relayout(TFE_Context* context,
TFE_TensorHandle* handle,
const tensorflow::TF_Layout* layout,
const char* device_name,
TF_Status* status);

} /* end extern "C" */

#endif // THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_DTENSOR_API_H_
Loading

0 comments on commit 3ade095

Please sign in to comment.