From 3ade095b1120b55adbde03b8c72ae614264f542b Mon Sep 17 00:00:00 2001
From: Isha Arkatkar
Date: Tue, 20 Dec 2022 21:24:35 -0800
Subject: [PATCH] Introduce a wrapper C++ API for dtensor device to be used by
 Tensorflow Federated.

PiperOrigin-RevId: 496825787
---
Review notes (ignored by git; not part of the commit message):
- Line breaks were rebuilt from the "+" markers of a whitespace-collapsed copy
  of this patch; indentation follows clang-format / buildifier conventions.
- Every "<...>" token was missing from that copy. Template arguments were
  restored from the deleters and return types in use. System include targets
  could not be recovered and were re-derived from the names each file uses
  (the header needs none) -- confirm against the upstream revision.
- The author address on the "From:" line was also missing; add it before
  feeding this file to `git am`.
- TFE_DTENSOR_DTensorToTensor: return right after a failed Relayout and release
  the intermediate handle when reading its device pointer fails.
- TFE_DTENSOR_CopyToMesh / TFE_DTENSOR_Relayout: output handles start as
  nullptr.
- dtensor_api_test.cc: the second global device is CPU:1 (was CPU:0 twice),
  and the byte size is asserted before memcpy into the fixed-size buffer.
- Hunk sizes and the diffstat reflect these edits; the "index" hashes are
  those of the original revision.

 .../cc/core/impl/executors/BUILD              |  45 +++
 .../cc/core/impl/executors/dtensor_api.cc     | 195 ++++++++++++
 .../cc/core/impl/executors/dtensor_api.h      |  74 +++++
 .../core/impl/executors/dtensor_api_test.cc   | 283 ++++++++++++++++++
 4 files changed, 597 insertions(+)
 create mode 100644 tensorflow_federated/cc/core/impl/executors/dtensor_api.cc
 create mode 100644 tensorflow_federated/cc/core/impl/executors/dtensor_api.h
 create mode 100644 tensorflow_federated/cc/core/impl/executors/dtensor_api_test.cc

diff --git a/tensorflow_federated/cc/core/impl/executors/BUILD b/tensorflow_federated/cc/core/impl/executors/BUILD
index 5d7ce10bba..e5c23f5c5a 100644
--- a/tensorflow_federated/cc/core/impl/executors/BUILD
+++ b/tensorflow_federated/cc/core/impl/executors/BUILD
@@ -196,6 +196,51 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "dtensor_api",
+    srcs = ["dtensor_api.cc"],
+    hdrs = ["dtensor_api.h"],
+    deps = [
+        "@com_google_absl//absl/strings",
+        "@org_tensorflow//tensorflow/c:c_api_experimental",
+        "@org_tensorflow//tensorflow/c:tf_datatype",
+        "@org_tensorflow//tensorflow/c:tf_status_headers",
+        "@org_tensorflow//tensorflow/c/eager:c_api",
+        "@org_tensorflow//tensorflow/c/eager:tfe_op_internal",
+        "@org_tensorflow//tensorflow/core:core_cpu_base",
+        "@org_tensorflow//tensorflow/core/common_runtime:core",
+        "@org_tensorflow//tensorflow/dtensor/cc:dtensor_device_cc",
+        "@org_tensorflow//tensorflow/dtensor/cc:dtensor_device_util",
+        "@org_tensorflow//tensorflow/dtensor/cc:mesh_type",
+        "@org_tensorflow//tensorflow/dtensor/cc:tensor_layout",
+    ],
+)
+
+cc_test(
+    name = "dtensor_api_test",
+    srcs = ["dtensor_api_test.cc"],
+    deps = [
+        ":dtensor_api",
+        "//tensorflow_federated/cc/common_libs:oss_test_main",
+        "@com_google_absl//absl/log:check",
+        "@org_tensorflow//tensorflow/c:tf_status_headers",
+        "@org_tensorflow//tensorflow/c:tf_tensor_internal",
+        "@org_tensorflow//tensorflow/c/eager:c_api",
+        "@org_tensorflow//tensorflow/c/eager:c_api_internal",
+        "@org_tensorflow//tensorflow/c/eager:immediate_execution_tensor_handle",
+        "@org_tensorflow//tensorflow/c/eager:tfe_context_internal",
+        "@org_tensorflow//tensorflow/c/eager:tfe_tensorhandle_internal",
+        "@org_tensorflow//tensorflow/core:protos_all_cc",
+        "@org_tensorflow//tensorflow/core/platform:status",
+        "@org_tensorflow//tensorflow/dtensor/cc:dstatus",
+        "@org_tensorflow//tensorflow/dtensor/cc:mesh_type",
+        "@org_tensorflow//tensorflow/dtensor/cc:tensor_layout",
+        "@org_tensorflow//tensorflow/dtensor/proto:layout_proto_cc",
+        "@org_tensorflow//tensorflow/tsl/platform:status",
+        "@org_tensorflow//tensorflow/tsl/platform:statusor",
+    ],
+)
+
 cc_library(
     name = "dtensor_executor",
     srcs = ["dtensor_executor.cc"],
diff --git a/tensorflow_federated/cc/core/impl/executors/dtensor_api.cc b/tensorflow_federated/cc/core/impl/executors/dtensor_api.cc
new file mode 100644
index 0000000000..43b0b5c5d8
--- /dev/null
+++ b/tensorflow_federated/cc/core/impl/executors/dtensor_api.cc
@@ -0,0 +1,195 @@
+/* Copyright 2022, The TensorFlow Federated Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License
+==============================================================================*/
+#include "tensorflow_federated/cc/core/impl/executors/dtensor_api.h"
+
+#include <cstring>
+#include <memory>
+#include <string>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/tfe_op_internal.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/dtensor/cc/dtensor_device.h"
+#include "tensorflow/dtensor/cc/dtensor_device_util.h"
+#include "tensorflow/dtensor/cc/mesh_type.h"
+#include "tensorflow/dtensor/cc/tensor_layout.h"
+
+extern "C" {
+
+void* TFE_DTENSOR_RegisterDTensorDevice(TFE_Context* context,
+                                        tensorflow::TF_Mesh* mesh,
+                                        const char* dtensor_device_name,
+                                        TF_Status* status) {
+  TFE_CustomDevice device;
+  void* device_info;
+  tensorflow::dtensor::AllocateDTensorDevice(
+      /*device_name=*/dtensor_device_name, &device, &device_info);
+
+  std::string mesh_string = tensorflow::unwrap(mesh)->ToString();
+  TFE_RegisterCustomDevice(context, device, dtensor_device_name, device_info,
+                           status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  tensorflow::dtensor::AddMesh(mesh_string, device_info, /*is_async=*/false,
+                               /*is_host_mesh=*/false,
+                               /*in_flight_nodes_limit=*/0, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  return device_info;
+}
+
+bool TFE_DTENSOR_IsTensorHandleOnDevice(TFE_Context* context,
+                                        TFE_TensorHandle* tensor_handle,
+                                        const char* device_name,
+                                        TF_Status* status) {
+  const char* tensor_device = TFE_TensorHandleDeviceName(tensor_handle, status);
+  if (TF_GetCode(status) != TF_OK) return false;
+  return std::strcmp(tensor_device, device_name) == 0;
+}
+
+TFE_TensorHandle* TFE_DTENSOR_TensorToDTensor(
+    TFE_Context* context, TFE_TensorHandle* handle,
+    const tensorflow::TF_Layout* layout, const char* device_name,
+    TF_Status* status) {
+  const tensorflow::dtensor::Layout* layout_object = tensorflow::unwrap(layout);
+
+  if (layout_object->IsFullyReplicated()) {
+    TFE_TensorHandle* replicated_result =
+        TFE_DTENSOR_CopyToMesh(context, handle, layout, device_name, status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    return replicated_result;
+  }
+
+  // Perform copy to mesh followed by relayout to get result
+  auto replicated_layout = tensorflow::dtensor::Layout::ReplicatedOnMesh(
+      layout_object->mesh(), layout_object->rank());
+  TFE_TensorHandle* replicated_result = TFE_DTENSOR_CopyToMesh(
+      context, handle, tensorflow::wrap(&replicated_layout), device_name,
+      status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  TFE_TensorHandle* result = TFE_DTENSOR_Relayout(context, replicated_result,
+                                                  layout, device_name, status);
+  // Delete intermediate result handle from copying to mesh.
+  TFE_DeleteTensorHandle(replicated_result);
+  return result;
+}
+
+TFE_TensorHandle* TFE_DTENSOR_DTensorToTensor(TFE_Context* context,
+                                              TFE_TensorHandle* dtensor_handle,
+                                              const char* device_name,
+                                              TF_Status* status) {
+  tensorflow::dtensor::TensorWithLayout* t =
+      reinterpret_cast<tensorflow::dtensor::TensorWithLayout*>(
+          TFE_TensorHandleDevicePointer(dtensor_handle, status));
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  if (t->layout().IsFullyReplicated()) {
+    // Get the tensor value
+    return TFE_TensorHandleCopySharingTensor(t->get_tensor(0), status);
+  }
+
+  auto replicated_layout = tensorflow::dtensor::Layout::ReplicatedOnMesh(
+      t->layout().mesh(), t->layout().rank());
+
+  TFE_TensorHandle* result = TFE_DTENSOR_Relayout(
+      context, dtensor_handle, tensorflow::wrap(&replicated_layout),
+      device_name, status);
+  // Bail out before touching `result`, which is null when Relayout fails.
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  tensorflow::dtensor::TensorWithLayout* t_replicated =
+      reinterpret_cast<tensorflow::dtensor::TensorWithLayout*>(
+          TFE_TensorHandleDevicePointer(result, status));
+  if (TF_GetCode(status) != TF_OK) {
+    // Release the intermediate handle on this error path as well.
+    TFE_DeleteTensorHandle(result);
+    return nullptr;
+  }
+
+  auto tensor =
+      TFE_TensorHandleCopySharingTensor(t_replicated->get_tensor(0), status);
+
+  TFE_DeleteTensorHandle(result);
+  return tensor;
+}
+
+TFE_TensorHandle* TFE_DTENSOR_CopyToMesh(TFE_Context* context,
+                                         TFE_TensorHandle* tensor_handle,
+                                         const tensorflow::TF_Layout* layout,
+                                         const char* device_name,
+                                         TF_Status* status) {
+  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+      TFE_NewOp(context, "CopyToMesh", status), TFE_DeleteOp);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  TFE_OpSetDevice(op.get(), device_name, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  std::string serialized_layout = tensorflow::unwrap(layout)->ToString();
+  TFE_OpSetAttrString(op.get(), "layout", serialized_layout.data(),
+                      serialized_layout.length());
+  TFE_OpAddInput(op.get(), tensor_handle, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  int num_results = 1;
+  TFE_TensorHandle* replicated_result = nullptr;
+  TFE_Execute(op.get(), &replicated_result, &num_results, status);
+
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  return replicated_result;
+}
+
+TFE_TensorHandle* TFE_DTENSOR_Relayout(TFE_Context* context,
+                                       TFE_TensorHandle* handle,
+                                       const tensorflow::TF_Layout* layout,
+                                       const char* device_name,
+                                       TF_Status* status) {
+  bool is_dtensor =
+      TFE_DTENSOR_IsTensorHandleOnDevice(context, handle, device_name, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  if (!is_dtensor) {
+    TF_SetStatus(
+        status, TF_INVALID_ARGUMENT,
+        absl::StrCat("Input to Relayout should be a DTensor on device ",
+                     device_name)
+            .c_str());
+    return nullptr;
+  }
+  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> relayout(
+      TFE_NewOp(context, "Relayout", status), TFE_DeleteOp);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpSetDevice(relayout.get(), device_name, status);
+
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  std::string serialized_layout = tensorflow::unwrap(layout)->ToString();
+  TFE_OpSetAttrString(relayout.get(), "layout", serialized_layout.data(),
+                      serialized_layout.length());
+  TFE_OpAddInput(relayout.get(), handle, status);
+
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  int num_results = 1;
+  TFE_TensorHandle* result = nullptr;
+  TFE_Execute(relayout.get(), &result, &num_results, status);
+
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  return result;
+}
+}
diff --git a/tensorflow_federated/cc/core/impl/executors/dtensor_api.h b/tensorflow_federated/cc/core/impl/executors/dtensor_api.h
new file mode 100644
index 0000000000..42748b9d1e
--- /dev/null
+++ b/tensorflow_federated/cc/core/impl/executors/dtensor_api.h
@@ -0,0 +1,74 @@
+/* Copyright 2022, The TensorFlow Federated Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_DTENSOR_API_H_
+#define THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_DTENSOR_API_H_
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/dtensor/cc/mesh_type.h"
+
+extern "C" {
+
+// Registers a DTensor device with provided mesh.
+// Returns a DeviceInfo object which can be used to add mesh.
+void* TFE_DTENSOR_RegisterDTensorDevice(TFE_Context* context,
+                                        tensorflow::TF_Mesh* mesh,
+                                        const char* dtensor_device_name,
+                                        TF_Status* status);
+
+// Returns true, if given tensor_handle points to a DTensor on provided device
+// name.
+bool TFE_DTENSOR_IsTensorHandleOnDevice(TFE_Context* context,
+                                        TFE_TensorHandle* tensor_handle,
+                                        const char* device_name,
+                                        TF_Status* status);
+
+// Copies a Tensor to DTensor by sharding or replicating the input tensor
+// according to specified layout.
+TFE_TensorHandle* TFE_DTENSOR_TensorToDTensor(
+    TFE_Context* context, TFE_TensorHandle* tensor_handle,
+    const tensorflow::TF_Layout* layout, const char* device_name,
+    TF_Status* status);
+
+// Copies input DTensor to Tensor, by removing the sharding and
+// returns the global tensor value handle.
+TFE_TensorHandle* TFE_DTENSOR_DTensorToTensor(TFE_Context* context,
+                                              TFE_TensorHandle* dtensor_handle,
+                                              const char* device_name,
+                                              TF_Status* status);
+
+// Copies a Tensor onto mesh with replicated layout and returns DTensor.
+// CopyToMesh only supports replicated layout.
+// Input handle to CopyToMesh is expected to be a regular tensor.
+TFE_TensorHandle* TFE_DTENSOR_CopyToMesh(TFE_Context* context,
+                                         TFE_TensorHandle* tensor_handle,
+                                         const tensorflow::TF_Layout* layout,
+                                         const char* device_name,
+                                         TF_Status* status);
+
+// Changes the layout of input DTensor to provided layout and returns resulting
+// DTensor handle.
+// Note that input handle is expected to be DTensor handle, passing a regular
+// tensor to Relayout will result in an invalid argument status.
+// TODO(b/256948367): Relayout does not support complex dtypes and some dtypes
+// on GPU. Add documentation on supported types and fix the support for dtypes.
+TFE_TensorHandle* TFE_DTENSOR_Relayout(TFE_Context* context,
+                                       TFE_TensorHandle* handle,
+                                       const tensorflow::TF_Layout* layout,
+                                       const char* device_name,
+                                       TF_Status* status);
+
+} /* end extern "C" */
+
+#endif  // THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_DTENSOR_API_H_
diff --git a/tensorflow_federated/cc/core/impl/executors/dtensor_api_test.cc b/tensorflow_federated/cc/core/impl/executors/dtensor_api_test.cc
new file mode 100644
index 0000000000..893675f98b
--- /dev/null
+++ b/tensorflow_federated/cc/core/impl/executors/dtensor_api_test.cc
@@ -0,0 +1,283 @@
+/* Copyright 2022, The TensorFlow Federated Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License
+==============================================================================*/
+#include "tensorflow_federated/cc/core/impl/executors/dtensor_api.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "googlemock/include/gmock/gmock.h"
+#include "googletest/include/gtest/gtest.h"
+#include "absl/log/check.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/c/eager/tfe_context_internal.h"
+#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_tensor_internal.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/dtensor/cc/dstatus.h"
+#include "tensorflow/dtensor/cc/mesh_type.h"
+#include "tensorflow/dtensor/cc/tensor_layout.h"
+#include "tensorflow/dtensor/proto/layout.pb.h"
+#include "tensorflow/tsl/platform/status.h"
+#include "tensorflow/tsl/platform/statusor.h"
+
+namespace tensorflow_federated {
+namespace dtensor {
+namespace {
+using ::testing::HasSubstr;
+
+tensorflow::StatusOr<tensorflow::dtensor::Layout> ShardedOnFirstDimLayout(
+    int rank, const tensorflow::dtensor::Mesh& mesh) {
+  std::vector<std::string> sharding_specs;
+  sharding_specs.push_back("x");
+  for (int i = 1; i < rank; ++i) {
+    sharding_specs.push_back(tensorflow::dtensor::Layout::kUnshardedDim);
+  }
+  return tensorflow::dtensor::Layout::GetLayout(sharding_specs, mesh);
+}
+
+tensorflow::dtensor::MeshProto CreateMeshForTest() {
+  tensorflow::dtensor::MeshProto mesh;
+
+  tensorflow::dtensor::MeshDimensionProto* dimension =
+      mesh.add_mesh_dimensions();
+  dimension->set_name("x");
+  dimension->set_size(2);
+  mesh.add_local_devices("/job:localhost/replica:0/task:0/device:CPU:0");
+  mesh.add_local_devices("/job:localhost/replica:0/task:0/device:CPU:1");
+  mesh.add_global_devices("/job:localhost/replica:0/task:0/device:CPU:0");
+  mesh.add_global_devices("/job:localhost/replica:0/task:0/device:CPU:1");
+  mesh.add_local_device_ids(0);
+  mesh.add_local_device_ids(1);
+  mesh.add_global_device_ids(0);
+  mesh.add_global_device_ids(1);
+  return mesh;
+}
+
+tensorflow::Tensor CreateIntTensor(tensorflow::TensorShape shape,
+                                   const std::vector<int32_t>& elements) {
+  CHECK(shape.num_elements() == elements.size());
+  tensorflow::Tensor tensor(tensorflow::DT_INT32, shape);
+  auto flat = tensor.flat<int32_t>();
+  for (size_t i = 0; i < elements.size(); i++) {
+    flat(i) = elements[i];
+  }
+  return tensor;
+}
+
+class DTensorAPITest : public ::testing::Test {
+ public:
+  DTensorAPITest() {
+    TF_Status* status = TF_NewStatus();
+    std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)>
+        opts(TFE_NewContextOptions(), TFE_DeleteContextOptions);
+    std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+        TFE_NewContext(opts.get(), status), TFE_DeleteContext);
+
+    TFE_SetLogicalCpuDevices(context.get(), 2,
+                             "/job:localhost/replica:0/task:0", status);
+    device_name_ = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
+    TF_DeleteStatus(status);
+  }
+
+  std::string device_name_;
+};
+
+TEST_F(DTensorAPITest, CheckTensorDTensorWithShardedLayout) {
+  TF_Status* status = TF_NewStatus();
+  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+      TFE_NewContextOptions(), TFE_DeleteContextOptions);
+  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+      TFE_NewContext(opts.get(), status), TFE_DeleteContext);
+  TFE_SetLogicalCpuDevices(context.get(), 2, "/job:localhost/replica:0/task:0",
+                           status);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto mesh, tensorflow::dtensor::Mesh::ParseFromProto(
+                                         CreateMeshForTest()));
+  TFE_DTENSOR_RegisterDTensorDevice(context.get(), tensorflow::wrap(&mesh),
+                                    device_name_.c_str(), status);
+
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  TF_ASSERT_OK_AND_ASSIGN(auto layout, ShardedOnFirstDimLayout(1, mesh));
+
+  tensorflow::Tensor tensor =
+      CreateIntTensor(tensorflow::TensorShape({4}), {1, 2, 3, 4});
+  TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
+
+  TFE_TensorHandle* dtensor_handle = TFE_DTENSOR_TensorToDTensor(
+      context.get(), tensor_handle, tensorflow::wrap(&layout),
+      device_name_.c_str(), status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  tensorflow::ImmediateExecutionTensorHandle* dtensor =
+      tensorflow::unwrap(dtensor_handle);
+  std::string summary;
+  ASSERT_TRUE(dtensor->SummarizeValue(summary).ok());
+  EXPECT_THAT(summary, AllOf(HasSubstr("{\"CPU:0\": [1 2], \"CPU:1\": [3 4]}"),
+                             HasSubstr("x")));
+  EXPECT_THAT(dtensor->DebugString(),
+              AllOf(HasSubstr("dtype=DT_INT32"),
+                    HasSubstr("{\"CPU:0\": [1 2], \"CPU:1\": [3 4]}"),
+                    HasSubstr("sharding_specs:x")));
+
+  TFE_TensorHandle* converted_tensor_handle = TFE_DTENSOR_DTensorToTensor(
+      context.get(), dtensor_handle, device_name_.c_str(), status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_tensor(
+      TFE_TensorHandleResolve(converted_tensor_handle, status),
+      TF_DeleteTensor);
+
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  int32_t expected[4];
+  // Guard the fixed-size buffer before copying the resolved tensor into it.
+  ASSERT_EQ(TF_TensorByteSize(value_tensor.get()), sizeof(expected));
+  std::memcpy(&expected[0], TF_TensorData(value_tensor.get()),
+              TF_TensorByteSize(value_tensor.get()));
+  EXPECT_EQ(1, expected[0]);
+  EXPECT_EQ(2, expected[1]);
+  EXPECT_EQ(3, expected[2]);
+  EXPECT_EQ(4, expected[3]);
+
+  TF_DeleteStatus(status);
+  TFE_DeleteTensorHandle(tensor_handle);
+  TFE_DeleteTensorHandle(dtensor_handle);
+  TFE_DeleteTensorHandle(converted_tensor_handle);
+}
+
+TEST_F(DTensorAPITest, CheckCopyToMesh) {
+  TF_Status* status = TF_NewStatus();
+  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+      TFE_NewContextOptions(), TFE_DeleteContextOptions);
+  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+      TFE_NewContext(opts.get(), status), TFE_DeleteContext);
+  TFE_SetLogicalCpuDevices(context.get(), 2, "/job:localhost/replica:0/task:0",
+                           status);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto mesh, tensorflow::dtensor::Mesh::ParseFromProto(
+                                         CreateMeshForTest()));
+  TFE_DTENSOR_RegisterDTensorDevice(context.get(), tensorflow::wrap(&mesh),
+                                    device_name_.c_str(), status);
+
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  auto layout = tensorflow::dtensor::Layout::ReplicatedOnMesh(mesh, 1);
+
+  tensorflow::Tensor tensor =
+      CreateIntTensor(tensorflow::TensorShape({4}), {1, 2, 3, 4});
+  TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
+
+  TFE_TensorHandle* dtensor_handle = TFE_DTENSOR_CopyToMesh(
+      context.get(), tensor_handle, tensorflow::wrap(&layout),
+      device_name_.c_str(), status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  tensorflow::ImmediateExecutionTensorHandle* dtensor =
+      tensorflow::unwrap(dtensor_handle);
+  std::string summary;
+  ASSERT_TRUE(dtensor->SummarizeValue(summary).ok());
+  EXPECT_THAT(summary, AllOf(HasSubstr("[1 2 3 4]"), HasSubstr("unsharded")));
+  EXPECT_THAT(dtensor->DebugString(),
+              AllOf(HasSubstr("dtype=DT_INT32"), HasSubstr("[1 2 3 4]"),
+                    HasSubstr("sharding_specs:unsharded")));
+
+  // Reading tensor value from replicated DTensor is allowed.
+  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_tensor(
+      TFE_TensorHandleResolve(dtensor_handle, status), TF_DeleteTensor);
+
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  int32_t expected[4];
+  // Guard the fixed-size buffer before copying the resolved tensor into it.
+  ASSERT_EQ(TF_TensorByteSize(value_tensor.get()), sizeof(expected));
+  std::memcpy(&expected[0], TF_TensorData(value_tensor.get()),
+              TF_TensorByteSize(value_tensor.get()));
+  EXPECT_EQ(1, expected[0]);
+  EXPECT_EQ(2, expected[1]);
+  EXPECT_EQ(3, expected[2]);
+  EXPECT_EQ(4, expected[3]);
+
+  TF_DeleteStatus(status);
+  TFE_DeleteTensorHandle(tensor_handle);
+  TFE_DeleteTensorHandle(dtensor_handle);
+}
+
+TEST_F(DTensorAPITest, CheckRelayout) {
+  TF_Status* status = TF_NewStatus();
+  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+      TFE_NewContextOptions(), TFE_DeleteContextOptions);
+  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+      TFE_NewContext(opts.get(), status), TFE_DeleteContext);
+  TFE_SetLogicalCpuDevices(context.get(), 2, "/job:localhost/replica:0/task:0",
+                           status);
+  TF_ASSERT_OK_AND_ASSIGN(auto mesh, tensorflow::dtensor::Mesh::ParseFromProto(
+                                         CreateMeshForTest()));
+  TFE_DTENSOR_RegisterDTensorDevice(context.get(), tensorflow::wrap(&mesh),
+                                    device_name_.c_str(), status);
+
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  tensorflow::Tensor tensor =
+      CreateIntTensor(tensorflow::TensorShape({4}), {1, 2, 3, 4});
+  TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
+
+  auto replicated_layout =
+      tensorflow::dtensor::Layout::ReplicatedOnMesh(mesh, 1);
+  TFE_TensorHandle* dtensor_handle = TFE_DTENSOR_CopyToMesh(
+      context.get(), tensor_handle, tensorflow::wrap(&replicated_layout),
+      device_name_.c_str(), status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  tensorflow::ImmediateExecutionTensorHandle* dtensor =
+      tensorflow::unwrap(dtensor_handle);
+  std::string summary;
+  ASSERT_TRUE(dtensor->SummarizeValue(summary).ok());
+  EXPECT_THAT(summary, AllOf(HasSubstr("[1 2 3 4]"), HasSubstr("unsharded")));
+  EXPECT_THAT(dtensor->DebugString(),
+              AllOf(HasSubstr("dtype=DT_INT32"), HasSubstr("[1 2 3 4]"),
+                    HasSubstr("sharding_specs:unsharded")));
+
+  TF_ASSERT_OK_AND_ASSIGN(auto layout, ShardedOnFirstDimLayout(1, mesh));
+  TFE_TensorHandle* relayout_dtensor_handle = TFE_DTENSOR_Relayout(
+      context.get(), dtensor_handle, tensorflow::wrap(&layout),
+      device_name_.c_str(), status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  tensorflow::ImmediateExecutionTensorHandle* dtensor1 =
+      tensorflow::unwrap(relayout_dtensor_handle);
+  std::string summary1;
+  ASSERT_TRUE(dtensor1->SummarizeValue(summary1).ok());
+  EXPECT_THAT(summary1, AllOf(HasSubstr("{\"CPU:0\": [1 2], \"CPU:1\": [3 4]}"),
+                              HasSubstr("x")));
+  EXPECT_THAT(dtensor1->DebugString(),
+              AllOf(HasSubstr("dtype=DT_INT32"),
+                    HasSubstr("{\"CPU:0\": [1 2], \"CPU:1\": [3 4]}"),
+                    HasSubstr("sharding_specs:x")));
+
+  TF_DeleteStatus(status);
+  TFE_DeleteTensorHandle(tensor_handle);
+  TFE_DeleteTensorHandle(dtensor_handle);
+  TFE_DeleteTensorHandle(relayout_dtensor_handle);
+}
+}  // namespace
+}  // namespace dtensor
+}  // namespace tensorflow_federated