From 0d740364d9410ee52c02bc0a81c4f55ef026f172 Mon Sep 17 00:00:00 2001
From: Zachary Garrett
Date: Wed, 18 Jan 2023 13:35:45 -0800
Subject: [PATCH] Fix a bug in XlaExecutor to always expect results as tuples
 (possibly of length 1), to match MLIR-based JAX tracing.

PiperOrigin-RevId: 502965936
---
 .../cc/core/impl/executors/BUILD              |  7 +-
 .../cc/core/impl/executors/xla_executor.cc    | 22 ++++-
 .../core/impl/executors/xla_executor_test.cc  | 38 +++++++-
 .../python/core/backends/native/BUILD         |  1 +
 .../python/core/backends/xla/BUILD            |  6 +-
 .../backends/xla/cpp_execution_contexts.py    | 25 +++--
 .../xla/cpp_execution_contexts_test.py        | 95 +++++++++++--------
 7 files changed, 135 insertions(+), 59 deletions(-)

diff --git a/tensorflow_federated/cc/core/impl/executors/BUILD b/tensorflow_federated/cc/core/impl/executors/BUILD
index a0488a1c12..1bc540f1b5 100644
--- a/tensorflow_federated/cc/core/impl/executors/BUILD
+++ b/tensorflow_federated/cc/core/impl/executors/BUILD
@@ -1013,7 +1013,7 @@ cc_library(
     ],
 )
 
-cc_test(
+tff_cc_cpu_gpu_test(
     name = "xla_executor_test",
     timeout = "short",
     srcs = ["xla_executor_test.cc"],
@@ -1027,12 +1027,17 @@ cc_test(
         "//tensorflow_federated/proto/v0:computation_cc_proto",
         "@com_google_absl//absl/status",
         "@org_tensorflow//tensorflow/compiler/jit:xla_cpu_jit",  # buildcleaner: keep  # Linking in this dependency ensures that XLA can compile its code for the CPU host.
+        "@org_tensorflow//tensorflow/compiler/jit:xla_gpu_jit",  # buildcleaner: keep  # Linking in this dependency ensures that XLA can compile its code for the GPU.
         "@org_tensorflow//tensorflow/compiler/tf2xla:common",
         "@org_tensorflow//tensorflow/compiler/xla:shape_util",
         "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto_cc",
         "@org_tensorflow//tensorflow/compiler/xla/client:xla_builder",
         "@org_tensorflow//tensorflow/compiler/xla/client:xla_computation",
+        "@org_tensorflow//tensorflow/compiler/xla/service:platform_util",
+        "@org_tensorflow//tensorflow/compiler/xla/stream_executor:stream_executor_headers",
+        "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cuda_platform",  # buildcleaner: keep  # Linking in the CUDA platform here ensures that the stream executor can execute on GPU.
+        "@org_tensorflow//tensorflow/compiler/xla/stream_executor/host:host_platform",  # buildcleaner: keep  # Linking in the host platform here ensures that the stream executor can execute on CPU.
+        "@org_tensorflow//tensorflow/core:framework",
         "@org_tensorflow//tensorflow/core:protos_all_cc",
     ],
 )
diff --git a/tensorflow_federated/cc/core/impl/executors/xla_executor.cc b/tensorflow_federated/cc/core/impl/executors/xla_executor.cc
index b1440dc62b..dd825e4d5e 100644
--- a/tensorflow_federated/cc/core/impl/executors/xla_executor.cc
+++ b/tensorflow_federated/cc/core/impl/executors/xla_executor.cc
@@ -627,11 +627,24 @@ class XLAExecutor : public ExecutorBase {
   const v0::Xla::Binding& result_binding = fn->result_binding();
   switch (result_binding.binding_case()) {
     case v0::Xla::Binding::kTensor: {
-      // Note that we assume the serialization logic is correct, and that if
-      // the XLA computation here declares a tensor binding, then it truly
-      // returns a single value of tensor type.
+      // JAX tracing always compiles results into tuples, so a single tensor
+      // result arrives as a length-1 tuple.
+      tensorflow::StatusOr<std::vector<std::unique_ptr<xla::GlobalData>>>
+          maybe_global_data_vector = xla_client_->DeconstructTuple(**result);
+      if (!maybe_global_data_vector.ok()) {
+        return absl::InternalError(absl::StrCat(
+            "Error destructuring tuple in XLA executor. Message: ",
+            maybe_global_data_vector.status().error_message()));
+      }
+      if (maybe_global_data_vector->size() != 1) {
+        return absl::InternalError(
+            absl::StrCat("Expected a 1-tuple representing a single tensor "
+                         "output, instead output was a tuple with ",
+                         maybe_global_data_vector->size(), " elements."));
+      }
       return TFFTypeAndGlobalDataToValue(
-          fn->type().function().result().tensor(), std::move(*result));
+          fn->type().function().result().tensor(),
+          std::move(maybe_global_data_vector.value()[0]));
     }
     case v0::Xla::Binding::kStruct: {
       tensorflow::StatusOr<std::vector<std::unique_ptr<xla::GlobalData>>>
@@ -699,6 +712,7 @@ absl::StatusOr<xla::Client*> GetXLAClient(absl::string_view platform_name) {
 
 absl::StatusOr<std::shared_ptr<Executor>> CreateXLAExecutor(
     absl::string_view platform_name) {
+  LOG(INFO) << "Creating XLAExecutor for platform: " << platform_name;
   xla::Client* client = TFF_TRY(GetXLAClient(platform_name));
   return std::make_shared<XLAExecutor>(client);
 }
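
Note: the executor change above relies on JAX's MLIR-based lowering always
wrapping results in a tuple, even for a single tensor. A quick way to see this
from Python (a sketch; `jax.xla_computation` was the inspection API available
when this patch was written, exact HLO text varies across JAX versions, and
`add_one` is just an illustrative function):

    import jax
    import numpy as np

    def add_one(x):
      return x + 1.0

    # Lower to XLA HLO; the ROOT instruction is a 1-tuple even though the
    # function has a single result, e.g.:
    #   ROOT tuple.N = (f32[]) tuple(add.M)
    computation = jax.xla_computation(add_one)(np.float32(0.0))
    print(computation.as_hlo_text())
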
diff --git a/tensorflow_federated/cc/core/impl/executors/xla_executor_test.cc b/tensorflow_federated/cc/core/impl/executors/xla_executor_test.cc
index d7235d9c82..4b117fb731 100644
--- a/tensorflow_federated/cc/core/impl/executors/xla_executor_test.cc
+++ b/tensorflow_federated/cc/core/impl/executors/xla_executor_test.cc
@@ -28,9 +28,12 @@ limitations under the License
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/stream_executor/platform.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/device_factory.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow_federated/cc/core/impl/executors/executor.h"
 #include "tensorflow_federated/cc/core/impl/executors/status_matchers.h"
@@ -133,7 +136,25 @@ inline xla::Shape XLAShapeWithUnknownDims(tensorflow::DataType dtype,
 
 class XLAExecutorTest : public ::testing::Test {
  public:
-  XLAExecutorTest() { test_executor_ = CreateXLAExecutor("Host").value(); }
+  XLAExecutorTest() {
+    xla::StatusOr<std::vector<stream_executor::Platform*>> platforms =
+        xla::PlatformUtil::GetSupportedPlatforms();
+    if (!platforms.ok()) {
+      LOG(FATAL) << "Could not enumerate supported XLA platforms";
+    }
+    LOG(INFO) << "Found " << platforms->size() << " platforms";
+    for (auto* platform : *platforms) {
+      LOG(INFO) << "Platform: " << platform->Name();
+      if (platform->Name() == "CUDA") {
+        if (platform->VisibleDeviceCount() > 0) {
+          test_executor_ = CreateXLAExecutor(platform->Name()).value();
+          return;
+        }
+      }
+    }
+    // If no accelerators exist, use the "Host" (CPU) platform.
+    test_executor_ = CreateXLAExecutor("Host").value();
+  }
   std::shared_ptr<Executor> test_executor_;
 
   void CheckMaterializeEqual(ValueId id, v0::Value expected_result) {
@@ -383,7 +404,12 @@ TEST_F(XLAExecutorTest, CreateValueComputationTensorParameterUnknownRankFails) {
 
 TEST_F(XLAExecutorTest, CreateAndMaterializeNoArgCallSingleTensor) {
   xla::XlaBuilder builder("return_two");
-  xla::ConstantR0<float>(&builder, 2.0);
+  xla::XlaOp constant = xla::ConstantR0<float>(&builder, 2.0);
+  // Mimic the Python tracing, which always returns tuples, even for
+  // single-element results, after passing through MLIR
+  // (https://github.com/google/jax/blob/38f91bdaa564a4de1e06bde7d191af0bff610bbf/jax/_src/api.py#L958):
+  // results are always wrapped in tuples.
+  xla::Tuple(&builder, {constant});
   tensorflow::StatusOr<xla::XlaComputation> xla_computation = builder.Build();
   ASSERT_TRUE(xla_computation.ok());
   auto tensor_type = TensorT(v0::TensorType::DT_FLOAT);
@@ -459,7 +485,13 @@ TEST_F(XLAExecutorTest, CreateAndMaterializeNoArgCallNestedTensorStructure) {
 
 TEST_F(XLAExecutorTest, CreateAndMaterializeIdentityScalar) {
   xla::XlaBuilder builder("float_scalar_identity");
-  xla::Parameter(&builder, 0, xla::ShapeUtil::MakeScalarShape(xla::F32), "x");
+  xla::XlaOp parameter = xla::Parameter(
+      &builder, 0, xla::ShapeUtil::MakeScalarShape(xla::F32), "x");
+  // Mimic the Python tracing, which always returns tuples, even for
+  // single-element results, after passing through MLIR
+  // (https://github.com/google/jax/blob/38f91bdaa564a4de1e06bde7d191af0bff610bbf/jax/_src/api.py#L958):
+  // results are always wrapped in tuples.
+  xla::Tuple(&builder, {parameter});
   tensorflow::StatusOr<xla::XlaComputation> xla_computation = builder.Build();
   ASSERT_TRUE(xla_computation.ok());
   v0::Type float_tensor_type = TensorT(v0::TensorType::DT_FLOAT);
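
Note: the C++ builder changes above have a direct Python analogue via
jax.lib's xla_client, the module the old test below used. A sketch, assuming
`xla_client.ops.Tuple` wraps XLA's Tuple op as in jaxlib of that era; the
xla_client API surface varies across jaxlib versions:

    from jax.lib import xla_client
    import numpy as np

    builder = xla_client.XlaBuilder('return_two')
    constant = xla_client.ops.Constant(builder, np.float32(2.0))
    # Wrap the single result in a 1-tuple, as MLIR-based JAX tracing does.
    xla_client.ops.Tuple(builder, [constant])
    computation = builder.build()
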
diff --git a/tensorflow_federated/python/core/backends/native/BUILD b/tensorflow_federated/python/core/backends/native/BUILD
index cb3f6d6d34..f93bead776 100644
--- a/tensorflow_federated/python/core/backends/native/BUILD
+++ b/tensorflow_federated/python/core/backends/native/BUILD
@@ -3,6 +3,7 @@ load("@rules_python//python:defs.bzl", "py_library", "py_test")
 package(default_visibility = [
     ":native_packages",
     "//tensorflow_federated/python/core:core_users",
+    "//tensorflow_federated/python/core/backends/xla:xla_packages",
 
     # TODO(b/193543632): Remove this visibility once C++ execution is fully
     # supported in OSS.
diff --git a/tensorflow_federated/python/core/backends/xla/BUILD b/tensorflow_federated/python/core/backends/xla/BUILD
index a1961d8384..daa8936b1c 100644
--- a/tensorflow_federated/python/core/backends/xla/BUILD
+++ b/tensorflow_federated/python/core/backends/xla/BUILD
@@ -88,6 +88,7 @@ py_library(
         "nokokoro",  # b/193543632: C++ execution is not fully supported in OSS.
     ],
     deps = [
+        "//tensorflow_federated/python/core/backends/native:compiler",
         "//tensorflow_federated/python/core/impl/computation:computation_base",
         "//tensorflow_federated/python/core/impl/context_stack:set_default_context",
         "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context",
@@ -104,14 +105,11 @@ py_test(
     ],
     deps = [
         ":cpp_execution_contexts",
-        "//tensorflow_federated/python/core/impl/computation:computation_impl",
         "//tensorflow_federated/python/core/impl/context_stack:context_base",
-        "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl",
         "//tensorflow_federated/python/core/impl/federated_context:federated_computation",
         "//tensorflow_federated/python/core/impl/federated_context:intrinsics",
+        "//tensorflow_federated/python/core/impl/jax_context:jax_computation",
         "//tensorflow_federated/python/core/impl/types:computation_types",
-        "//tensorflow_federated/python/core/impl/types:placements",
-        "//tensorflow_federated/python/core/impl/xla_context:xla_serialization",
     ],
 )
diff --git a/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py b/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py
index fa3002911c..95d2e3b8b5 100644
--- a/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py
+++ b/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Execution contexts for the XLA backend."""
 
+from tensorflow_federated.python.core.backends.native import compiler as native_compiler
 from tensorflow_federated.python.core.impl.computation import computation_base
 from tensorflow_federated.python.core.impl.context_stack import set_default_context
 from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context
@@ -20,16 +21,19 @@
 from tensorflow_federated.python.core.impl.executors import executor_bindings
 
 
-def set_local_cpp_execution_context(default_num_clients: int = 0,
-                                    max_concurrent_computation_calls: int = -1):
+def set_local_cpp_execution_context(
+    default_num_clients: int = 0, max_concurrent_computation_calls: int = -1
+):
   context = create_local_cpp_execution_context(
       default_num_clients=default_num_clients,
-      max_concurrent_computation_calls=max_concurrent_computation_calls)
+      max_concurrent_computation_calls=max_concurrent_computation_calls,
+  )
   set_default_context.set_default_context(context)
 
 
 def create_local_cpp_execution_context(
-    default_num_clients: int = 0, max_concurrent_computation_calls: int = -1):
+    default_num_clients: int = 0, max_concurrent_computation_calls: int = -1
+):
   """Creates a local execution context backed by TFF-C++ runtime.
 
   Args:
@@ -47,18 +51,19 @@ def create_local_cpp_execution_context(
   def leaf_executor_fn(max_concurrent_computation_calls):
     del max_concurrent_computation_calls  # Unused.
     xla_executor_fn = executor_bindings.create_xla_executor()
-    executor_bindings.create_sequence_executor(xla_executor_fn)
+    return executor_bindings.create_sequence_executor(xla_executor_fn)
 
   factory = cpp_executor_factory.local_cpp_executor_factory(
       default_num_clients=default_num_clients,
      max_concurrent_computation_calls=max_concurrent_computation_calls,
-      leaf_executor_fn=leaf_executor_fn)
+      leaf_executor_fn=leaf_executor_fn,
+  )
 
   def compiler_fn(comp: computation_base.Computation):
-    # TODO(b/255978089): Define compiler_fn - Integrate LocalComputationFactory
-    # with intrinsic reductions
-    del comp  # Unused.
-    return None
+    # TODO(b/255978089): Implement lowering to federated_aggregate that
+    # creates JAX computations instead of TensorFlow, similar to the
+    # "desugar intrinsics" pass in the native backend.
+    return native_compiler.transform_to_native_form(comp)
 
   return sync_execution_context.SyncExecutionContext(
       executor_fn=factory, compiler_fn=compiler_fn
diff --git a/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts_test.py b/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts_test.py
index 6e4a9c8210..61eeb1fd9d 100644
--- a/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts_test.py
+++ b/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts_test.py
@@ -13,73 +13,94 @@
 # limitations under the License.
 
 from absl.testing import absltest
-from jax.lib import xla_client
 import numpy as np
 
 from tensorflow_federated.python.core.backends.xla import cpp_execution_contexts
-from tensorflow_federated.python.core.impl.computation import computation_impl
 from tensorflow_federated.python.core.impl.context_stack import context_base
-from tensorflow_federated.python.core.impl.context_stack import context_stack_impl
 from tensorflow_federated.python.core.impl.federated_context import federated_computation
 from tensorflow_federated.python.core.impl.federated_context import intrinsics
+from tensorflow_federated.python.core.impl.jax_context import jax_computation
 from tensorflow_federated.python.core.impl.types import computation_types
-from tensorflow_federated.python.core.impl.types import placements
-from tensorflow_federated.python.core.impl.xla_context import xla_serialization
 
 
 class CppExecutionContextsTest(absltest.TestCase):
 
+  def setUp(self):
+    super().setUp()
+    cpp_execution_contexts.set_local_cpp_execution_context()
+
   def test_create_local_execution_context(self):
     context = cpp_execution_contexts.create_local_cpp_execution_context()
     self.assertIsInstance(context, context_base.SyncContext)
 
-  def test_set_local_cpp_execution_context_and_run_simple_xla_computation(self):
-    self.skipTest(
-        'b/255978089: requires transformation to lowering of TFF computations'
-        ' to XLA'
-    )
-    builder = xla_client.XlaBuilder('comp')
-    xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(tuple()))
-    xla_client.ops.Constant(builder, np.int32(10))
-    xla_comp = builder.build()
-    comp_type = computation_types.FunctionType(None, np.int32)
-    comp_pb = xla_serialization.create_xla_tff_computation(
-        xla_comp, [], comp_type
-    )
-    ctx_stack = context_stack_impl.context_stack
-    comp = computation_impl.ConcreteComputation(comp_pb, ctx_stack)
-    cpp_execution_contexts.set_local_cpp_execution_context()
-    self.assertEqual(comp(), 10)
+  def test_run_simple_jax_computation(self):
+    @jax_computation.jax_computation(np.float32, np.float32)
+    def comp(a, b):
+      return a + b
+
+    self.assertEqual(comp(np.float32(10.0), np.float32(20.0)), 30.0)
+
+  def test_run_federated_aggregate(self):
+    @jax_computation.jax_computation(np.float32, np.float32)
+    def _add(a, b):
+      return a + b
+
+    @jax_computation.jax_computation(np.float32)
+    def _identity(a):
+      return a
+
+    # IMPORTANT: we must wrap the zero literal in a `jax_computation` because
+    # TFF by default wraps Python literals as `tf.constant`, which will fail
+    # in this execution context.
+    @jax_computation.jax_computation
+    def zeros():
+      return np.float32(0)
+
+    @federated_computation.federated_computation(
+        computation_types.at_clients(np.float32)
+    )
+    def aggregate(client_values):
+      return intrinsics.federated_aggregate(
+          client_values,
+          zero=zeros(),
+          accumulate=_add,
+          merge=_add,
+          report=_identity,
+      )
 
-  def test_federated_sum_in_xla_execution_context(self):
-    self.skipTest(
-        'b/255978089: requires transformation to lowering of federated_sum'
-        ' to XLA'
+    self.assertEqual(
+        aggregate([np.float32(1), np.float32(2), np.float32(3)]), np.float32(6)
     )
 
+  def test_federated_sum(self):
     @federated_computation.federated_computation(
-        computation_types.FederatedType(np.int32, placements.CLIENTS)
+        computation_types.at_clients(np.int32)
     )
     def comp(x):
       return intrinsics.federated_sum(x)
 
-    cpp_execution_contexts.set_local_cpp_execution_context()
-    self.assertEqual(comp([1, 2, 3]), 6)
-
-  def test_unweighted_federated_mean_in_xla_execution_context(self):
-    self.skipTest(
-        'b/255978089: requires transformation to lowering of federated_mean'
-        ' to XLA'
-    )
+    with self.assertRaisesRegex(
+        Exception, 'Unsupported intrinsic URI: federated_sum'
+    ):
+      # TODO(b/255978089): Implement intrinsic lowering using JAX
+      # computations; the compiler currently generates TF logic, which fails.
+      # self.assertEqual(comp([1, 2, 3]), 6)
+      comp([1, 2, 3])
 
+  def test_unweighted_federated_mean(self):
     @federated_computation.federated_computation(
-        computation_types.FederatedType(np.float32, placements.CLIENTS)
+        computation_types.at_clients(np.float32)
     )
     def comp(x):
       return intrinsics.federated_mean(x)
 
-    cpp_execution_contexts.set_local_cpp_execution_context()
-    self.assertEqual(comp([1.0, 2.0, 3.0]), 2.0)
+    with self.assertRaisesRegex(
+        Exception, 'Unsupported intrinsic URI: federated_mean'
+    ):
+      # TODO(b/255978089): Implement intrinsic lowering using JAX
+      # computations; the compiler currently generates TF logic, which fails.
+      # self.assertEqual(comp([1.0, 2.0, 3.0]), 2.0)
+      comp([1.0, 2.0, 3.0])
 
 
 if __name__ == '__main__':
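
Note: the federated_aggregate test above doubles as a standalone usage example
for the new backend. A sketch using the same helpers and import paths as the
test (paths may differ in released TFF):

    import numpy as np
    from tensorflow_federated.python.core.backends.xla import cpp_execution_contexts
    from tensorflow_federated.python.core.impl.federated_context import federated_computation
    from tensorflow_federated.python.core.impl.federated_context import intrinsics
    from tensorflow_federated.python.core.impl.jax_context import jax_computation
    from tensorflow_federated.python.core.impl.types import computation_types

    cpp_execution_contexts.set_local_cpp_execution_context()

    @jax_computation.jax_computation(np.float32, np.float32)
    def add(a, b):
      return a + b

    @jax_computation.jax_computation(np.float32)
    def identity(a):
      return a

    # Wrap the zero literal in a jax_computation so TFF does not wrap it as a
    # tf.constant, which would fail in this execution context.
    @jax_computation.jax_computation
    def zeros():
      return np.float32(0)

    @federated_computation.federated_computation(
        computation_types.at_clients(np.float32))
    def aggregate(client_values):
      return intrinsics.federated_aggregate(
          client_values,
          zero=zeros(),
          accumulate=add,
          merge=add,
          report=identity)

    print(aggregate([np.float32(1), np.float32(2), np.float32(3)]))  # 6.0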