Skip to content

Commit

Permalink
Sort dtype switch statements consistently.
Browse files Browse the repository at this point in the history
This change is a follow up to cl/662652837 and should reduce merge conflicts with upcoming changes to serialization (cl/658906580) and make maintaining this code easier.

PiperOrigin-RevId: 662966562
  • Loading branch information
michaelreneer authored and copybara-github committed Aug 14, 2024
1 parent f16df18 commit c860c83
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 110 deletions.
42 changes: 21 additions & 21 deletions tensorflow_federated/cc/core/impl/executors/array_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,21 +124,24 @@ inline absl::StatusOr<v0::Array> CreateArray(
return array_pb;
}

// Overload for Eigen::bfloat16.
// Overload for complex.
template <typename T>
inline absl::StatusOr<v0::Array> CreateArray(
v0::DataType dtype, v0::ArrayShape shape_pb,
std::initializer_list<const Eigen::bfloat16> values) {
std::initializer_list<std::complex<T>> values) {
v0::Array array_pb;
array_pb.set_dtype(dtype);
array_pb.mutable_shape()->Swap(&shape_pb);
const T* begin = reinterpret_cast<const T*>(values.begin());
switch (dtype) {
case v0::DataType::DT_BFLOAT16: {
auto size = values.size();
array_pb.mutable_bfloat16_list()->mutable_value()->Reserve(size);
for (auto element : values) {
array_pb.mutable_bfloat16_list()->mutable_value()->AddAlreadyReserved(
Eigen::numext::bit_cast<uint16_t>(element));
}
case v0::DataType::DT_COMPLEX64: {
array_pb.mutable_complex64_list()->mutable_value()->Assign(
begin, begin + values.size() * 2);
break;
}
case v0::DataType::DT_COMPLEX128: {
array_pb.mutable_complex128_list()->mutable_value()->Assign(
begin, begin + values.size() * 2);
break;
}
default:
Expand All @@ -148,24 +151,21 @@ inline absl::StatusOr<v0::Array> CreateArray(
return array_pb;
}

// Overload for complex.
template <typename T>
// Overload for Eigen::bfloat16.
inline absl::StatusOr<v0::Array> CreateArray(
v0::DataType dtype, v0::ArrayShape shape_pb,
std::initializer_list<std::complex<T>> values) {
std::initializer_list<const Eigen::bfloat16> values) {
v0::Array array_pb;
array_pb.set_dtype(dtype);
array_pb.mutable_shape()->Swap(&shape_pb);
const T* begin = reinterpret_cast<const T*>(values.begin());
switch (dtype) {
case v0::DataType::DT_COMPLEX64: {
array_pb.mutable_complex64_list()->mutable_value()->Assign(
begin, begin + values.size() * 2);
break;
}
case v0::DataType::DT_COMPLEX128: {
array_pb.mutable_complex128_list()->mutable_value()->Assign(
begin, begin + values.size() * 2);
case v0::DataType::DT_BFLOAT16: {
auto size = values.size();
array_pb.mutable_bfloat16_list()->mutable_value()->Reserve(size);
for (auto element : values) {
array_pb.mutable_bfloat16_list()->mutable_value()->AddAlreadyReserved(
Eigen::numext::bit_cast<uint16_t>(element));
}
break;
}
default:
Expand Down
31 changes: 16 additions & 15 deletions tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField<int32_t>
return Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16_t>(x));
});
}

// Overload for complex.
template <typename T>
static void CopyFromRepeatedField(const google::protobuf::RepeatedField<T>& src,
std::complex<T>* dest) {
std::copy(src.begin(), src.end(), reinterpret_cast<T*>(dest));
}

// Overload for Eigen::bfloat16.
static void CopyFromRepeatedField(const google::protobuf::RepeatedField<int32_t>& src,
Eigen::bfloat16* dest) {
Expand All @@ -101,13 +109,6 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField<int32_t>
});
}

// Overload for complex.
template <typename T>
static void CopyFromRepeatedField(const google::protobuf::RepeatedField<T>& src,
std::complex<T>* dest) {
std::copy(src.begin(), src.end(), reinterpret_cast<T*>(dest));
}

// Overload for string.
static void CopyFromRepeatedField(
const google::protobuf::RepeatedPtrField<std::string>& src,
Expand Down Expand Up @@ -197,14 +198,6 @@ absl::StatusOr<tensorflow::Tensor> TensorFromArray(const v0::Array& array_pb) {
tensor.flat<Eigen::half>().data());
return tensor;
}
case v0::Array::kBfloat16List: {
tensorflow::Tensor tensor(
tensorflow::DataTypeToEnum<Eigen::bfloat16>::value,
TFF_TRY(TensorShapeFromArrayShape(array_pb.shape())));
CopyFromRepeatedField(array_pb.bfloat16_list().value(),
tensor.flat<Eigen::bfloat16>().data());
return tensor;
}
case v0::Array::kFloat32List: {
tensorflow::Tensor tensor(
tensorflow::DataTypeToEnum<float>::value,
Expand Down Expand Up @@ -237,6 +230,14 @@ absl::StatusOr<tensorflow::Tensor> TensorFromArray(const v0::Array& array_pb) {
tensor.flat<tensorflow::complex128>().data());
return tensor;
}
case v0::Array::kBfloat16List: {
tensorflow::Tensor tensor(
tensorflow::DataTypeToEnum<Eigen::bfloat16>::value,
TFF_TRY(TensorShapeFromArrayShape(array_pb.shape())));
CopyFromRepeatedField(array_pb.bfloat16_list().value(),
tensor.flat<Eigen::bfloat16>().data());
return tensor;
}
case v0::Array::kStringList: {
tensorflow::Tensor tensor(
tensorflow::DataTypeToEnum<tensorflow::tstring>::value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,6 @@ INSTANTIATE_TEST_SUITE_P(
.value(),
tensorflow::test::AsScalar(Eigen::half{1.0}),
},
{
"bfloat16",
testing::CreateArray(v0::DataType::DT_BFLOAT16,
testing::CreateArrayShape({}),
{Eigen::bfloat16{1.0}})
.value(),
tensorflow::test::AsScalar(Eigen::bfloat16{1.0}),
},
{
"float32",
testing::CreateArray(v0::DataType::DT_FLOAT,
Expand Down Expand Up @@ -252,6 +244,14 @@ INSTANTIATE_TEST_SUITE_P(
.value(),
tensorflow::test::AsScalar(tensorflow::complex128{1.0, 1.0}),
},
{
"bfloat16",
testing::CreateArray(v0::DataType::DT_BFLOAT16,
testing::CreateArrayShape({}),
{Eigen::bfloat16{1.0}})
.value(),
tensorflow::test::AsScalar(Eigen::bfloat16{1.0}),
},
{
"string",
testing::CreateArray(v0::DataType::DT_STRING,
Expand Down
30 changes: 15 additions & 15 deletions tensorflow_federated/cc/core/impl/executors/xla_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,14 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField<int32_t>
});
}

// Overload for Eigen::bflot16.
// Overload for complex.
template <typename T>
static void CopyFromRepeatedField(const google::protobuf::RepeatedField<T>& src,
std::complex<T>* dest) {
std::copy(src.begin(), src.end(), reinterpret_cast<T*>(dest));
}

// Overload for Eigen::bfloat16.
static void CopyFromRepeatedField(const google::protobuf::RepeatedField<int32_t>& src,
Eigen::bfloat16* dest) {
// Values of dtype ml_dtypes.bfloat16 are packed to and unpacked from a
Expand All @@ -133,13 +140,6 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField<int32_t>
});
}

// Overload for complex.
template <typename T>
static void CopyFromRepeatedField(const google::protobuf::RepeatedField<T>& src,
std::complex<T>* dest) {
std::copy(src.begin(), src.end(), reinterpret_cast<T*>(dest));
}

absl::StatusOr<xla::Literal> LiteralFromArray(const v0::Array& array_pb) {
switch (array_pb.kind_case()) {
case v0::Array::kBoolList: {
Expand Down Expand Up @@ -212,13 +212,6 @@ absl::StatusOr<xla::Literal> LiteralFromArray(const v0::Array& array_pb) {
literal.data<xla::half>().begin());
return literal;
}
case v0::Array::kBfloat16List: {
xla::Literal literal(TFF_TRY(
ShapeFromArrayShape(v0::DataType::DT_BFLOAT16, array_pb.shape())));
CopyFromRepeatedField(array_pb.bfloat16_list().value(),
literal.data<xla::bfloat16>().begin());
return literal;
}
case v0::Array::kFloat32List: {
xla::Literal literal(TFF_TRY(
ShapeFromArrayShape(v0::DataType::DT_FLOAT, array_pb.shape())));
Expand Down Expand Up @@ -247,6 +240,13 @@ absl::StatusOr<xla::Literal> LiteralFromArray(const v0::Array& array_pb) {
literal.data<xla::complex128>().begin());
return literal;
}
case v0::Array::kBfloat16List: {
xla::Literal literal(TFF_TRY(
ShapeFromArrayShape(v0::DataType::DT_BFLOAT16, array_pb.shape())));
CopyFromRepeatedField(array_pb.bfloat16_list().value(),
literal.data<xla::bfloat16>().begin());
return literal;
}
default:
return absl::UnimplementedError(
absl::StrCat("Unexpected DataType found:", array_pb.kind_case()));
Expand Down
26 changes: 13 additions & 13 deletions tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,19 +239,6 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_float16) {
EXPECT_EQ(actual_literal, expected_literal);
}

TEST(LiteralFromArrayTest, TestReturnsLiteral_bfloat16) {
const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray(
v0::DataType::DT_BFLOAT16, testing::CreateArrayShape({}),
{Eigen::bfloat16{1.0}}));

const xla::Literal& actual_literal =
TFF_ASSERT_OK(LiteralFromArray(array_pb));

xla::Literal expected_literal =
xla::LiteralUtil::CreateR0(Eigen::bfloat16{1.0});
EXPECT_EQ(actual_literal, expected_literal);
}

TEST(LiteralFromArrayTest, TestReturnsLiteral_float32) {
const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray(
v0::DataType::DT_FLOAT, testing::CreateArrayShape({}), {1.0}));
Expand Down Expand Up @@ -300,6 +287,19 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_complex128) {
EXPECT_EQ(actual_literal, expected_literal);
}

TEST(LiteralFromArrayTest, TestReturnsLiteral_bfloat16) {
const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray(
v0::DataType::DT_BFLOAT16, testing::CreateArrayShape({}),
{Eigen::bfloat16{1.0}}));

const xla::Literal& actual_literal =
TFF_ASSERT_OK(LiteralFromArray(array_pb));

xla::Literal expected_literal =
xla::LiteralUtil::CreateR0(Eigen::bfloat16{1.0});
EXPECT_EQ(actual_literal, expected_literal);
}

TEST(LiteralFromArrayTest, TestReturnsLiteral_array) {
const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray(
v0::DataType::DT_INT32, testing::CreateArrayShape({2, 3}),
Expand Down
14 changes: 7 additions & 7 deletions tensorflow_federated/python/core/impl/compiler/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,6 @@ def from_proto(array_pb: array_pb2.Array) -> Array:
# compatibility with how other external environments (e.g., TensorFlow, JAX)
# represent values of `np.float16`.
value = np.asarray(value, np.uint16).view(np.float16).tolist()
elif dtype is ml_dtypes.bfloat16:
value = array_pb.bfloat16_list.value
# Values of dtype `ml_dtypes.bfloat16` are packed to and unpacked from a
# protobuf field of type `int32` using the following logic in order to
# maintain compatibility with how other external environments (e.g.,
# TensorFlow, JAX) represent values of `ml_dtypes.bfloat16`.
value = np.asarray(value, np.uint16).view(ml_dtypes.bfloat16).tolist()
elif dtype is np.float32:
value = array_pb.float32_list.value
elif dtype is np.float64:
Expand All @@ -95,6 +88,13 @@ def from_proto(array_pb: array_pb2.Array) -> Array:
)
value = iter(array_pb.complex128_list.value)
value = [complex(real, imag) for real, imag in zip(value, value)]
elif dtype is ml_dtypes.bfloat16:
value = array_pb.bfloat16_list.value
# Values of dtype `ml_dtypes.bfloat16` are packed to and unpacked from a
# protobuf field of type `int32` using the following logic in order to
# maintain compatibility with how other external environments (e.g.,
# TensorFlow, JAX) represent values of `ml_dtypes.bfloat16`.
value = np.asarray(value, np.uint16).view(ml_dtypes.bfloat16).tolist()
elif dtype is np.str_:
value = array_pb.string_list.value
else:
Expand Down
60 changes: 30 additions & 30 deletions tensorflow_federated/python/core/impl/compiler/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,6 @@ class FromProtoTest(parameterized.TestCase):
),
np.float16(1.0),
),
(
'bfloat16',
array_pb2.Array(
dtype=data_type_pb2.DataType.DT_BFLOAT16,
shape=array_pb2.ArrayShape(dim=[]),
bfloat16_list=array_pb2.Array.IntList(
value=[
np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item()
]
),
),
ml_dtypes.bfloat16(1.0),
),
(
'float32',
array_pb2.Array(
Expand Down Expand Up @@ -168,6 +155,19 @@ class FromProtoTest(parameterized.TestCase):
),
np.complex128(1.0 + 1.0j),
),
(
'bfloat16',
array_pb2.Array(
dtype=data_type_pb2.DataType.DT_BFLOAT16,
shape=array_pb2.ArrayShape(dim=[]),
bfloat16_list=array_pb2.Array.IntList(
value=[
np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item()
]
),
),
ml_dtypes.bfloat16(1.0),
),
(
'str',
array_pb2.Array(
Expand Down Expand Up @@ -668,23 +668,6 @@ def test_returns_value_with_no_dtype_hint(self, value, expected_value):
complex128_list=array_pb2.Array.DoubleList(value=[1.0, 1.0]),
),
),
(
'bfloat16',
# Note: we must not use Python `float` here because ml_dtypes.bfloat16
# is declared as kind `V` (void) not `f` (float) to prevent numpy from
# trying to equate float16 and bfloat16 (which are not compatible).
ml_dtypes.bfloat16(1.0),
ml_dtypes.bfloat16,
array_pb2.Array(
dtype=data_type_pb2.DataType.DT_BFLOAT16,
shape=array_pb2.ArrayShape(dim=[]),
bfloat16_list=array_pb2.Array.IntList(
value=[
np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item()
]
),
),
),
(
'str',
'abc',
Expand Down Expand Up @@ -725,6 +708,23 @@ def test_returns_value_with_no_dtype_hint(self, value, expected_value):
int32_list=array_pb2.Array.IntList(value=[1]),
),
),
(
'generic_bfloat16',
# Note: we must not use Python `float` here because ml_dtypes.bfloat16
# is declared as kind `V` (void) not `f` (float) to prevent numpy from
# trying to equate float16 and bfloat16 (which are not compatible).
ml_dtypes.bfloat16(1.0),
ml_dtypes.bfloat16,
array_pb2.Array(
dtype=data_type_pb2.DataType.DT_BFLOAT16,
shape=array_pb2.ArrayShape(dim=[]),
bfloat16_list=array_pb2.Array.IntList(
value=[
np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item()
]
),
),
),
(
'generic_str',
np.str_('abc'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def test_to_proto_raises_not_implemented_error(self, dtype):
('uint16', np.uint16),
('uint32', np.uint32),
('uint64', np.uint64),
('bfloat16', ml_dtypes.bfloat16),
('float16', np.float16),
('float32', np.float32),
('float64', np.float64),
('complex64', np.complex64),
('complex128', np.complex128),
('bfloat16', ml_dtypes.bfloat16),
('str', np.str_),
)
def test_is_valid_dtype_returns_true(self, dtype):
Expand Down

0 comments on commit c860c83

Please sign in to comment.