Fix a bug in XlaExecutor to always expect results as tuples (possibly of length 1), to match MLIR-based JAX tracing.

PiperOrigin-RevId: 502965936
ZacharyGarrett authored and tensorflow-copybara committed Jan 18, 2023
1 parent 4829c26 commit 0d74036
Showing 7 changed files with 135 additions and 59 deletions.
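A minimal sketch (not part of this commit) of the JAX behavior that motivates it: MLIR-based JAX lowering wraps even a single tensor result in a 1-tuple, so the executor must deconstruct a tuple rather than read a bare tensor. The jax.jit(...).lower(...) and compiler_ir(dialect='hlo') calls are assumptions about the installed JAX version.

import jax
import numpy as np

def return_two(x):
  del x  # Unused; the function returns a constant.
  return np.float32(2.0)

# Lower through JAX's MLIR pipeline and print the resulting HLO; the
# computation root is a 1-tuple, e.g. `ROOT tuple.N = (f32[]) tuple(...)`.
lowered = jax.jit(return_two).lower(np.float32(0.0))
print(lowered.compiler_ir(dialect='hlo').as_hlo_text())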
7 changes: 6 additions & 1 deletion tensorflow_federated/cc/core/impl/executors/BUILD
@@ -1013,7 +1013,7 @@ cc_library(
],
)

cc_test(
tff_cc_cpu_gpu_test(
name = "xla_executor_test",
timeout = "short",
srcs = ["xla_executor_test.cc"],
@@ -1027,12 +1027,17 @@ cc_test(
"//tensorflow_federated/proto/v0:computation_cc_proto",
"@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/compiler/jit:xla_cpu_jit", # buildcleaner: keep # Linking in this dependency ensures that XLA can compile its code for the CPU host.
"@org_tensorflow//tensorflow/compiler/jit:xla_gpu_jit", # buildcleaner: keep # Linking in this dependency ensures that XLA can compile its code for the GPU.
"@org_tensorflow//tensorflow/compiler/tf2xla:common",
"@org_tensorflow//tensorflow/compiler/xla:shape_util",
"@org_tensorflow//tensorflow/compiler/xla:xla_data_proto_cc",
"@org_tensorflow//tensorflow/compiler/xla/client:xla_builder",
"@org_tensorflow//tensorflow/compiler/xla/client:xla_computation",
"@org_tensorflow//tensorflow/compiler/xla/service:platform_util",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor:stream_executor_headers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cuda_platform", # buildcleaner: keep # Linking in the host platform here ensures that the stream executor can execute on GPU.
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/host:host_platform", # buildcleaner: keep # Linking in the host platform here ensures that the stream executor can execute on CPU.
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],
)
22 changes: 18 additions & 4 deletions tensorflow_federated/cc/core/impl/executors/xla_executor.cc
@@ -627,11 +627,24 @@ class XLAExecutor : public ExecutorBase<ValueFuture> {
const v0::Xla::Binding& result_binding = fn->result_binding();
switch (result_binding.binding_case()) {
case v0::Xla::Binding::kTensor: {
// Note that we assume the serialization logic is correct, and that if
// the XLA computation here declares a tensor binding, then it truly
// returns a single value of tensor type.
// JAX tracing always compiles results to tuples, so a single tensor
// result arrives as a tuple of length 1.
tensorflow::StatusOr<std::vector<std::unique_ptr<xla::GlobalData>>>
maybe_global_data_vector = xla_client_->DeconstructTuple(**result);
if (!maybe_global_data_vector.ok()) {
return absl::InternalError(absl::StrCat(
"Error destructuring tuple in XLA executor. Message: ",
maybe_global_data_vector.status().error_message()));
}
if (maybe_global_data_vector->size() != 1) {
return absl::InternalError(
absl::StrCat("Expected a 1-tuple representing a single tensor "
"output; instead the output was a tuple with ",
maybe_global_data_vector->size(), " elements."));
}
return TFFTypeAndGlobalDataToValue(
fn->type().function().result().tensor(), std::move(*result));
fn->type().function().result().tensor(),
std::move(maybe_global_data_vector.value()[0]));
}
case v0::Xla::Binding::kStruct: {
tensorflow::StatusOr<std::vector<std::unique_ptr<xla::GlobalData>>>
@@ -699,6 +712,7 @@ absl::StatusOr<xla::Client*> GetXLAClient(absl::string_view platform_name) {

absl::StatusOr<std::shared_ptr<Executor>> CreateXLAExecutor(
absl::string_view platform_name) {
LOG(INFO) << "Creating XLAExecutor for platform: " << platform_name;
xla::Client* client = TFF_TRY(GetXLAClient(platform_name));
return std::make_shared<XLAExecutor>(client);
}
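The same contract, seen from Python: a sketch (an assumption built on jaxlib's xla_client, in the style of the test code this commit touches) showing the 1-tuple root that a kTensor-bound computation is now expected to produce.

from jax.lib import xla_client
import numpy as np

builder = xla_client.XlaBuilder('return_two')
constant = xla_client.ops.Constant(builder, np.float32(2.0))
# Wrap the single result in a 1-tuple, matching what MLIR-based JAX tracing
# emits and what the fixed executor unwraps via DeconstructTuple.
xla_client.ops.Tuple(builder, [constant])
computation = builder.build()
print(computation.as_hlo_text())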
38 changes: 35 additions & 3 deletions tensorflow_federated/cc/core/impl/executors/xla_executor_test.cc
@@ -28,9 +28,12 @@ limitations under the License
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/stream_executor/platform.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow_federated/cc/core/impl/executors/executor.h"
#include "tensorflow_federated/cc/core/impl/executors/status_matchers.h"
@@ -133,7 +136,25 @@ inline xla::Shape XLAShapeWithUnknownDims(tensorflow::DataType dtype,

class XLAExecutorTest : public ::testing::Test {
public:
XLAExecutorTest() { test_executor_ = CreateXLAExecutor("Host").value(); }
XLAExecutorTest() {
xla::StatusOr<std::vector<xla::se::Platform*>> platforms =
xla::PlatformUtil::GetSupportedPlatforms();
if (!platforms.ok()) {
LOG(FATAL) << "Could not enumerate supported XLA platforms";
}
LOG(INFO) << "Found " << platforms->size() << " platforms";
for (auto* platform : *platforms) {
LOG(INFO) << "Platform: " << platform->Name();
if (platform->Name() == "CUDA") {
if (platform->VisibleDeviceCount() > 0) {
test_executor_ = CreateXLAExecutor(platform->Name()).value();
return;
}
}
}
// If no accelerator is present, use the "Host" (CPU) platform.
test_executor_ = CreateXLAExecutor("Host").value();
}
std::shared_ptr<Executor> test_executor_;

void CheckMaterializeEqual(ValueId id, v0::Value expected_result) {
@@ -383,7 +404,12 @@ TEST_F(XLAExecutorTest, CreateValueComputationTensorParameterUnknownRankFails) {

TEST_F(XLAExecutorTest, CreateAndMaterializeNoArgCallSingleTensor) {
xla::XlaBuilder builder("return_two");
xla::ConstantR0<float>(&builder, 2.0);
xla::XlaOp constant = xla::ConstantR0<float>(&builder, 2.0);
// Mimic JAX's Python tracing, which always returns tuples after passing
// through MLIR, even for single-element results
// (https://github.com/google/jax/blob/38f91bdaa564a4de1e06bde7d191af0bff610bbf/jax/_src/api.py#L958).
xla::Tuple(&builder, {constant});
tensorflow::StatusOr<xla::XlaComputation> xla_computation = builder.Build();
ASSERT_TRUE(xla_computation.ok());
auto tensor_type = TensorT(v0::TensorType::DT_FLOAT);
@@ -459,7 +485,13 @@ TEST_F(XLAExecutorTest, CreateAndMaterializeNoArgCallNestedTensorStructure) {

TEST_F(XLAExecutorTest, CreateAndMaterializeIdentityScalar) {
xla::XlaBuilder builder("float_scalar_identity");
xla::Parameter(&builder, 0, xla::ShapeUtil::MakeScalarShape(xla::F32), "x");
xla::XlaOp parameter = xla::Parameter(
&builder, 0, xla::ShapeUtil::MakeScalarShape(xla::F32), "x");
// Mimic JAX's Python tracing, which always returns tuples after passing
// through MLIR, even for single-element results
// (https://github.com/google/jax/blob/38f91bdaa564a4de1e06bde7d191af0bff610bbf/jax/_src/api.py#L958).
xla::Tuple(&builder, {parameter});
tensorflow::StatusOr<xla::XlaComputation> xla_computation = builder.Build();
ASSERT_TRUE(xla_computation.ok());
v0::Type float_tensor_type = TensorT(v0::TensorType::DT_FLOAT);
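The test fixture above prefers a CUDA platform when a GPU is visible and otherwise falls back to the host CPU. A rough Python analogue of that selection logic (a sketch assuming JAX's public device API, not part of this commit):

import jax

def pick_platform() -> str:
  """Prefer CUDA when a GPU backend is available, else fall back to Host."""
  try:
    if jax.devices('gpu'):
      return 'CUDA'
  except RuntimeError:
    pass  # JAX raises RuntimeError when no GPU backend is available.
  return 'Host'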
1 change: 1 addition & 0 deletions tensorflow_federated/python/core/backends/native/BUILD
@@ -3,6 +3,7 @@ load("@rules_python//python:defs.bzl", "py_library", "py_test")
package(default_visibility = [
":native_packages",
"//tensorflow_federated/python/core:core_users",
"//tensorflow_federated/python/core/backends/xla:xla_packages",

# TODO(b/193543632): Remove this visibility once C++ execution is fully
# supported in OSS.
6 changes: 2 additions & 4 deletions tensorflow_federated/python/core/backends/xla/BUILD
@@ -88,6 +88,7 @@ py_library(
"nokokoro", # b/193543632: C++ execution is not fully supported in OSS.
],
deps = [
"//tensorflow_federated/python/core/backends/native:compiler",
"//tensorflow_federated/python/core/impl/computation:computation_base",
"//tensorflow_federated/python/core/impl/context_stack:set_default_context",
"//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context",
@@ -104,14 +105,11 @@ py_test(
],
deps = [
":cpp_execution_contexts",
"//tensorflow_federated/python/core/impl/computation:computation_impl",
"//tensorflow_federated/python/core/impl/context_stack:context_base",
"//tensorflow_federated/python/core/impl/context_stack:context_stack_impl",
"//tensorflow_federated/python/core/impl/federated_context:federated_computation",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
"//tensorflow_federated/python/core/impl/jax_context:jax_computation",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/xla_context:xla_serialization",
],
)

tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py
@@ -13,23 +13,27 @@
# limitations under the License.
"""Execution contexts for the XLA backend."""

from tensorflow_federated.python.core.backends.native import compiler as native_compiler
from tensorflow_federated.python.core.impl.computation import computation_base
from tensorflow_federated.python.core.impl.context_stack import set_default_context
from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context
from tensorflow_federated.python.core.impl.executor_stacks import cpp_executor_factory
from tensorflow_federated.python.core.impl.executors import executor_bindings


def set_local_cpp_execution_context(default_num_clients: int = 0,
max_concurrent_computation_calls: int = -1):
def set_local_cpp_execution_context(
default_num_clients: int = 0, max_concurrent_computation_calls: int = -1
):
context = create_local_cpp_execution_context(
default_num_clients=default_num_clients,
max_concurrent_computation_calls=max_concurrent_computation_calls)
max_concurrent_computation_calls=max_concurrent_computation_calls,
)
set_default_context.set_default_context(context)


def create_local_cpp_execution_context(
default_num_clients: int = 0, max_concurrent_computation_calls: int = -1):
default_num_clients: int = 0, max_concurrent_computation_calls: int = -1
):
"""Creates a local execution context backed by TFF-C++ runtime.
Args:
@@ -47,18 +51,19 @@ def create_local_cpp_execution_context(
def leaf_executor_fn(max_concurrent_computation_calls):
del max_concurrent_computation_calls # Unused.
xla_executor_fn = executor_bindings.create_xla_executor()
executor_bindings.create_sequence_executor(xla_executor_fn)
return executor_bindings.create_sequence_executor(xla_executor_fn)

factory = cpp_executor_factory.local_cpp_executor_factory(
default_num_clients=default_num_clients,
max_concurrent_computation_calls=max_concurrent_computation_calls,
leaf_executor_fn=leaf_executor_fn)
leaf_executor_fn=leaf_executor_fn,
)

def compiler_fn(comp: computation_base.Computation):
# TODO(b/255978089): Define compiler_fn - Integrate LocalComputationFactory
# with intrinsic reductions
del comp # Unused.
return None
# TODO(b/255978089): Implement lowering to federated_aggregate that creates
# JAX computations instead of TensorFlow, similar to "desugar intrinsics"
# in the native backend.
return native_compiler.transform_to_native_form(comp)

return sync_execution_context.SyncExecutionContext(
executor_fn=factory, compiler_fn=compiler_fn
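A usage sketch for the context defined above, mirroring the tests that follow (all imports are ones this commit already uses):

import numpy as np

from tensorflow_federated.python.core.backends.xla import cpp_execution_contexts
from tensorflow_federated.python.core.impl.jax_context import jax_computation

# Install the XLA-backed TFF-C++ execution context as the default.
cpp_execution_contexts.set_local_cpp_execution_context()

@jax_computation.jax_computation(np.float32, np.float32)
def add(a, b):
  return a + b

assert add(np.float32(10.0), np.float32(20.0)) == 30.0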
tensorflow_federated/python/core/backends/xla/cpp_execution_contexts_test.py
@@ -13,73 +13,94 @@
# limitations under the License.

from absl.testing import absltest
from jax.lib import xla_client
import numpy as np

from tensorflow_federated.python.core.backends.xla import cpp_execution_contexts
from tensorflow_federated.python.core.impl.computation import computation_impl
from tensorflow_federated.python.core.impl.context_stack import context_base
from tensorflow_federated.python.core.impl.context_stack import context_stack_impl
from tensorflow_federated.python.core.impl.federated_context import federated_computation
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.jax_context import jax_computation
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.xla_context import xla_serialization


class CppExecutionContextsTest(absltest.TestCase):

def setUp(self):
super().setUp()
cpp_execution_contexts.set_local_cpp_execution_context()

def test_create_local_execution_context(self):
context = cpp_execution_contexts.create_local_cpp_execution_context()
self.assertIsInstance(context, context_base.SyncContext)

def test_set_local_cpp_execution_context_and_run_simple_xla_computation(self):
self.skipTest(
'b/255978089: requires transformation to lowering of TFF computations'
' to XLA'
)
builder = xla_client.XlaBuilder('comp')
xla_client.ops.Parameter(builder, 0, xla_client.shape_from_pyval(tuple()))
xla_client.ops.Constant(builder, np.int32(10))
xla_comp = builder.build()
comp_type = computation_types.FunctionType(None, np.int32)
comp_pb = xla_serialization.create_xla_tff_computation(
xla_comp, [], comp_type
def test_run_simple_jax_computation(self):
@jax_computation.jax_computation(np.float32, np.float32)
def comp(a, b):
return a + b

self.assertEqual(comp(np.float32(10.0), np.float32(20.0)), 30.0)

def test_run_federated_aggregate(self):
@jax_computation.jax_computation(np.float32, np.float32)
def _add(a, b):
return a + b

@jax_computation.jax_computation(np.float32)
def _identity(a):
return a

# IMPORTANT: we must wrap the zero literal in a `jax_computation` because
# TFF by default wraps Python literals as `tf.constant`, which will fail
# in this execution context.
@jax_computation.jax_computation
def zeros():
return np.float32(0)

@federated_computation.federated_computation(
computation_types.at_clients(np.float32)
)
ctx_stack = context_stack_impl.context_stack
comp = computation_impl.ConcreteComputation(comp_pb, ctx_stack)
cpp_execution_contexts.set_local_cpp_execution_context()
self.assertEqual(comp(), 10)
def aggregate(client_values):
return intrinsics.federated_aggregate(
client_values,
zero=zeros(),
accumulate=_add,
merge=_add,
report=_identity,
)

def test_federated_sum_in_xla_execution_context(self):
self.skipTest(
'b/255978089: requires transformation to lowering of federated_sum'
' to XLA'
self.assertEqual(
aggregate([np.float32(1), np.float32(2), np.float32(3)]), np.float32(6)
)

def test_federated_sum(self):
@federated_computation.federated_computation(
computation_types.FederatedType(np.int32, placements.CLIENTS)
computation_types.at_clients(np.int32)
)
def comp(x):
return intrinsics.federated_sum(x)

cpp_execution_contexts.set_local_cpp_execution_context()
self.assertEqual(comp([1, 2, 3]), 6)

def test_unweighted_federated_mean_in_xla_execution_context(self):
self.skipTest(
'b/255978089: requires transformation to lowering of federated_mean'
' to XLA'
)
with self.assertRaisesRegex(
Exception, 'Unsupported intrinsic URI: federated_sum'
):
# TODO(b/255978089): Implement intrinsic lowering using JAX computations;
# the compiler currently generates TF logic, which will fail.
# self.assertEqual(comp([1, 2, 3]), 6)
comp([1, 2, 3])

def test_unweighted_federated_mean(self):
@federated_computation.federated_computation(
computation_types.FederatedType(np.float32, placements.CLIENTS)
computation_types.at_clients(np.float32)
)
def comp(x):
return intrinsics.federated_mean(x)

cpp_execution_contexts.set_local_cpp_execution_context()
self.assertEqual(comp([1.0, 2.0, 3.0]), 2.0)
with self.assertRaisesRegex(
Exception, 'Unsupported intrinsic URI: federated_mean'
):
# TODO(b/255978089): Implement intrinsic lowering using JAX computations;
# the compiler currently generates TF logic, which will fail.
# self.assertEqual(comp([1.0, 2.0, 3.0]), 2.0)
comp([1.0, 2.0, 3.0])


if __name__ == '__main__':
