From c860c836cd590481588b654feb856d12c3b2b5cb Mon Sep 17 00:00:00 2001 From: Michael Reneer Date: Wed, 14 Aug 2024 10:19:47 -0700 Subject: [PATCH] Sort dtype switch statements consistently. This change is a follow up to cl/662652837 and should reduce merge conflicts with upcoming changes to serialization (cl/658906580) and make maintaining this code easier. PiperOrigin-RevId: 662966562 --- .../cc/core/impl/executors/array_test_utils.h | 42 ++++++------- .../core/impl/executors/tensorflow_utils.cc | 31 +++++----- .../impl/executors/tensorflow_utils_test.cc | 16 ++--- .../cc/core/impl/executors/xla_utils.cc | 30 +++++----- .../cc/core/impl/executors/xla_utils_test.cc | 26 ++++---- .../python/core/impl/compiler/array.py | 14 ++--- .../python/core/impl/compiler/array_test.py | 60 +++++++++---------- .../core/impl/types/dtype_utils_test.py | 2 +- 8 files changed, 111 insertions(+), 110 deletions(-) diff --git a/tensorflow_federated/cc/core/impl/executors/array_test_utils.h b/tensorflow_federated/cc/core/impl/executors/array_test_utils.h index d3080c2247..0bdd2e87ce 100644 --- a/tensorflow_federated/cc/core/impl/executors/array_test_utils.h +++ b/tensorflow_federated/cc/core/impl/executors/array_test_utils.h @@ -124,21 +124,24 @@ inline absl::StatusOr CreateArray( return array_pb; } -// Overload for Eigen::bfloat16. +// Overload for complex. +template inline absl::StatusOr CreateArray( v0::DataType dtype, v0::ArrayShape shape_pb, - std::initializer_list values) { + std::initializer_list> values) { v0::Array array_pb; array_pb.set_dtype(dtype); array_pb.mutable_shape()->Swap(&shape_pb); + const T* begin = reinterpret_cast(values.begin()); switch (dtype) { - case v0::DataType::DT_BFLOAT16: { - auto size = values.size(); - array_pb.mutable_bfloat16_list()->mutable_value()->Reserve(size); - for (auto element : values) { - array_pb.mutable_bfloat16_list()->mutable_value()->AddAlreadyReserved( - Eigen::numext::bit_cast(element)); - } + case v0::DataType::DT_COMPLEX64: { + array_pb.mutable_complex64_list()->mutable_value()->Assign( + begin, begin + values.size() * 2); + break; + } + case v0::DataType::DT_COMPLEX128: { + array_pb.mutable_complex128_list()->mutable_value()->Assign( + begin, begin + values.size() * 2); break; } default: @@ -148,24 +151,21 @@ inline absl::StatusOr CreateArray( return array_pb; } -// Overload for complex. -template +// Overload for Eigen::bfloat16. inline absl::StatusOr CreateArray( v0::DataType dtype, v0::ArrayShape shape_pb, - std::initializer_list> values) { + std::initializer_list values) { v0::Array array_pb; array_pb.set_dtype(dtype); array_pb.mutable_shape()->Swap(&shape_pb); - const T* begin = reinterpret_cast(values.begin()); switch (dtype) { - case v0::DataType::DT_COMPLEX64: { - array_pb.mutable_complex64_list()->mutable_value()->Assign( - begin, begin + values.size() * 2); - break; - } - case v0::DataType::DT_COMPLEX128: { - array_pb.mutable_complex128_list()->mutable_value()->Assign( - begin, begin + values.size() * 2); + case v0::DataType::DT_BFLOAT16: { + auto size = values.size(); + array_pb.mutable_bfloat16_list()->mutable_value()->Reserve(size); + for (auto element : values) { + array_pb.mutable_bfloat16_list()->mutable_value()->AddAlreadyReserved( + Eigen::numext::bit_cast(element)); + } break; } default: diff --git a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc index 183669ee50..8a1d507535 100644 --- a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc +++ b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc @@ -89,6 +89,14 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField return Eigen::numext::bit_cast(static_cast(x)); }); } + +// Overload for complex. +template +static void CopyFromRepeatedField(const google::protobuf::RepeatedField& src, + std::complex* dest) { + std::copy(src.begin(), src.end(), reinterpret_cast(dest)); +} + // Overload for Eigen::bfloat16. static void CopyFromRepeatedField(const google::protobuf::RepeatedField& src, Eigen::bfloat16* dest) { @@ -101,13 +109,6 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField }); } -// Overload for complex. -template -static void CopyFromRepeatedField(const google::protobuf::RepeatedField& src, - std::complex* dest) { - std::copy(src.begin(), src.end(), reinterpret_cast(dest)); -} - // Overload for string. static void CopyFromRepeatedField( const google::protobuf::RepeatedPtrField& src, @@ -197,14 +198,6 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kBfloat16List: { - tensorflow::Tensor tensor( - tensorflow::DataTypeToEnum::value, - TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); - CopyFromRepeatedField(array_pb.bfloat16_list().value(), - tensor.flat().data()); - return tensor; - } case v0::Array::kFloat32List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, @@ -237,6 +230,14 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } + case v0::Array::kBfloat16List: { + tensorflow::Tensor tensor( + tensorflow::DataTypeToEnum::value, + TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); + CopyFromRepeatedField(array_pb.bfloat16_list().value(), + tensor.flat().data()); + return tensor; + } case v0::Array::kStringList: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, diff --git a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils_test.cc b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils_test.cc index a0a1f62001..cfdf431302 100644 --- a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils_test.cc @@ -214,14 +214,6 @@ INSTANTIATE_TEST_SUITE_P( .value(), tensorflow::test::AsScalar(Eigen::half{1.0}), }, - { - "bfloat16", - testing::CreateArray(v0::DataType::DT_BFLOAT16, - testing::CreateArrayShape({}), - {Eigen::bfloat16{1.0}}) - .value(), - tensorflow::test::AsScalar(Eigen::bfloat16{1.0}), - }, { "float32", testing::CreateArray(v0::DataType::DT_FLOAT, @@ -252,6 +244,14 @@ INSTANTIATE_TEST_SUITE_P( .value(), tensorflow::test::AsScalar(tensorflow::complex128{1.0, 1.0}), }, + { + "bfloat16", + testing::CreateArray(v0::DataType::DT_BFLOAT16, + testing::CreateArrayShape({}), + {Eigen::bfloat16{1.0}}) + .value(), + tensorflow::test::AsScalar(Eigen::bfloat16{1.0}), + }, { "string", testing::CreateArray(v0::DataType::DT_STRING, diff --git a/tensorflow_federated/cc/core/impl/executors/xla_utils.cc b/tensorflow_federated/cc/core/impl/executors/xla_utils.cc index b7e462defa..06da07e101 100644 --- a/tensorflow_federated/cc/core/impl/executors/xla_utils.cc +++ b/tensorflow_federated/cc/core/impl/executors/xla_utils.cc @@ -121,7 +121,14 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField }); } -// Overload for Eigen::bflot16. +// Overload for complex. +template +static void CopyFromRepeatedField(const google::protobuf::RepeatedField& src, + std::complex* dest) { + std::copy(src.begin(), src.end(), reinterpret_cast(dest)); +} + +// Overload for Eigen::bfloat16. static void CopyFromRepeatedField(const google::protobuf::RepeatedField& src, Eigen::bfloat16* dest) { // Values of dtype ml_dtypes.bfloat16 are packed to and unpacked from a @@ -133,13 +140,6 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField }); } -// Overload for complex. -template -static void CopyFromRepeatedField(const google::protobuf::RepeatedField& src, - std::complex* dest) { - std::copy(src.begin(), src.end(), reinterpret_cast(dest)); -} - absl::StatusOr LiteralFromArray(const v0::Array& array_pb) { switch (array_pb.kind_case()) { case v0::Array::kBoolList: { @@ -212,13 +212,6 @@ absl::StatusOr LiteralFromArray(const v0::Array& array_pb) { literal.data().begin()); return literal; } - case v0::Array::kBfloat16List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_BFLOAT16, array_pb.shape()))); - CopyFromRepeatedField(array_pb.bfloat16_list().value(), - literal.data().begin()); - return literal; - } case v0::Array::kFloat32List: { xla::Literal literal(TFF_TRY( ShapeFromArrayShape(v0::DataType::DT_FLOAT, array_pb.shape()))); @@ -247,6 +240,13 @@ absl::StatusOr LiteralFromArray(const v0::Array& array_pb) { literal.data().begin()); return literal; } + case v0::Array::kBfloat16List: { + xla::Literal literal(TFF_TRY( + ShapeFromArrayShape(v0::DataType::DT_BFLOAT16, array_pb.shape()))); + CopyFromRepeatedField(array_pb.bfloat16_list().value(), + literal.data().begin()); + return literal; + } default: return absl::UnimplementedError( absl::StrCat("Unexpected DataType found:", array_pb.kind_case())); diff --git a/tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc b/tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc index a36a39d88a..f9b00096fd 100644 --- a/tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc @@ -239,19 +239,6 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_float16) { EXPECT_EQ(actual_literal, expected_literal); } -TEST(LiteralFromArrayTest, TestReturnsLiteral_bfloat16) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_BFLOAT16, testing::CreateArrayShape({}), - {Eigen::bfloat16{1.0}})); - - const xla::Literal& actual_literal = - TFF_ASSERT_OK(LiteralFromArray(array_pb)); - - xla::Literal expected_literal = - xla::LiteralUtil::CreateR0(Eigen::bfloat16{1.0}); - EXPECT_EQ(actual_literal, expected_literal); -} - TEST(LiteralFromArrayTest, TestReturnsLiteral_float32) { const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( v0::DataType::DT_FLOAT, testing::CreateArrayShape({}), {1.0})); @@ -300,6 +287,19 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_complex128) { EXPECT_EQ(actual_literal, expected_literal); } +TEST(LiteralFromArrayTest, TestReturnsLiteral_bfloat16) { + const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( + v0::DataType::DT_BFLOAT16, testing::CreateArrayShape({}), + {Eigen::bfloat16{1.0}})); + + const xla::Literal& actual_literal = + TFF_ASSERT_OK(LiteralFromArray(array_pb)); + + xla::Literal expected_literal = + xla::LiteralUtil::CreateR0(Eigen::bfloat16{1.0}); + EXPECT_EQ(actual_literal, expected_literal); +} + TEST(LiteralFromArrayTest, TestReturnsLiteral_array) { const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( v0::DataType::DT_INT32, testing::CreateArrayShape({2, 3}), diff --git a/tensorflow_federated/python/core/impl/compiler/array.py b/tensorflow_federated/python/core/impl/compiler/array.py index 6abec43661..dd94e97507 100644 --- a/tensorflow_federated/python/core/impl/compiler/array.py +++ b/tensorflow_federated/python/core/impl/compiler/array.py @@ -68,13 +68,6 @@ def from_proto(array_pb: array_pb2.Array) -> Array: # compatibility with how other external environments (e.g., TensorFlow, JAX) # represent values of `np.float16`. value = np.asarray(value, np.uint16).view(np.float16).tolist() - elif dtype is ml_dtypes.bfloat16: - value = array_pb.bfloat16_list.value - # Values of dtype `ml_dtypes.bfloat16` are packed to and unpacked from a - # protobuf field of type `int32` using the following logic in order to - # maintain compatibility with how other external environments (e.g., - # TensorFlow, JAX) represent values of `ml_dtypes.bfloat16`. - value = np.asarray(value, np.uint16).view(ml_dtypes.bfloat16).tolist() elif dtype is np.float32: value = array_pb.float32_list.value elif dtype is np.float64: @@ -95,6 +88,13 @@ def from_proto(array_pb: array_pb2.Array) -> Array: ) value = iter(array_pb.complex128_list.value) value = [complex(real, imag) for real, imag in zip(value, value)] + elif dtype is ml_dtypes.bfloat16: + value = array_pb.bfloat16_list.value + # Values of dtype `ml_dtypes.bfloat16` are packed to and unpacked from a + # protobuf field of type `int32` using the following logic in order to + # maintain compatibility with how other external environments (e.g., + # TensorFlow, JAX) represent values of `ml_dtypes.bfloat16`. + value = np.asarray(value, np.uint16).view(ml_dtypes.bfloat16).tolist() elif dtype is np.str_: value = array_pb.string_list.value else: diff --git a/tensorflow_federated/python/core/impl/compiler/array_test.py b/tensorflow_federated/python/core/impl/compiler/array_test.py index f11f9b4254..7b3df68990 100644 --- a/tensorflow_federated/python/core/impl/compiler/array_test.py +++ b/tensorflow_federated/python/core/impl/compiler/array_test.py @@ -119,19 +119,6 @@ class FromProtoTest(parameterized.TestCase): ), np.float16(1.0), ), - ( - 'bfloat16', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BFLOAT16, - shape=array_pb2.ArrayShape(dim=[]), - bfloat16_list=array_pb2.Array.IntList( - value=[ - np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item() - ] - ), - ), - ml_dtypes.bfloat16(1.0), - ), ( 'float32', array_pb2.Array( @@ -168,6 +155,19 @@ class FromProtoTest(parameterized.TestCase): ), np.complex128(1.0 + 1.0j), ), + ( + 'bfloat16', + array_pb2.Array( + dtype=data_type_pb2.DataType.DT_BFLOAT16, + shape=array_pb2.ArrayShape(dim=[]), + bfloat16_list=array_pb2.Array.IntList( + value=[ + np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item() + ] + ), + ), + ml_dtypes.bfloat16(1.0), + ), ( 'str', array_pb2.Array( @@ -668,23 +668,6 @@ def test_returns_value_with_no_dtype_hint(self, value, expected_value): complex128_list=array_pb2.Array.DoubleList(value=[1.0, 1.0]), ), ), - ( - 'bfloat16', - # Note: we must not use Python `float` here because ml_dtypes.bfloat16 - # is declared as kind `V` (void) not `f` (float) to prevent numpy from - # trying to equate float16 and bfloat16 (which are not compatible). - ml_dtypes.bfloat16(1.0), - ml_dtypes.bfloat16, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BFLOAT16, - shape=array_pb2.ArrayShape(dim=[]), - bfloat16_list=array_pb2.Array.IntList( - value=[ - np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item() - ] - ), - ), - ), ( 'str', 'abc', @@ -725,6 +708,23 @@ def test_returns_value_with_no_dtype_hint(self, value, expected_value): int32_list=array_pb2.Array.IntList(value=[1]), ), ), + ( + 'generic_bfloat16', + # Note: we must not use Python `float` here because ml_dtypes.bfloat16 + # is declared as kind `V` (void) not `f` (float) to prevent numpy from + # trying to equate float16 and bfloat16 (which are not compatible). + ml_dtypes.bfloat16(1.0), + ml_dtypes.bfloat16, + array_pb2.Array( + dtype=data_type_pb2.DataType.DT_BFLOAT16, + shape=array_pb2.ArrayShape(dim=[]), + bfloat16_list=array_pb2.Array.IntList( + value=[ + np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item() + ] + ), + ), + ), ( 'generic_str', np.str_('abc'), diff --git a/tensorflow_federated/python/core/impl/types/dtype_utils_test.py b/tensorflow_federated/python/core/impl/types/dtype_utils_test.py index cb03392c58..35c392a795 100644 --- a/tensorflow_federated/python/core/impl/types/dtype_utils_test.py +++ b/tensorflow_federated/python/core/impl/types/dtype_utils_test.py @@ -48,12 +48,12 @@ def test_to_proto_raises_not_implemented_error(self, dtype): ('uint16', np.uint16), ('uint32', np.uint32), ('uint64', np.uint64), - ('bfloat16', ml_dtypes.bfloat16), ('float16', np.float16), ('float32', np.float32), ('float64', np.float64), ('complex64', np.complex64), ('complex128', np.complex128), + ('bfloat16', ml_dtypes.bfloat16), ('str', np.str_), ) def test_is_valid_dtype_returns_true(self, dtype):