From c18d9c2d0f09e3a1baf91bcf0bee0cced08090ee Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 2 May 2024 12:15:58 -0700 Subject: [PATCH] Implement Serialize and Deserialize methods for CheckpointAggregator. PiperOrigin-RevId: 630141071 --- .../cc/core/impl/aggregation/protocol/BUILD | 11 ++ .../protocol/checkpoint_aggregator.cc | 78 +++++++++- .../protocol/checkpoint_aggregator.h | 31 +++- .../protocol/checkpoint_aggregator.proto | 8 + .../protocol/checkpoint_aggregator_test.cc | 138 ++++++++++++++++-- 5 files changed, 245 insertions(+), 21 deletions(-) create mode 100644 tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.proto diff --git a/tensorflow_federated/cc/core/impl/aggregation/protocol/BUILD b/tensorflow_federated/cc/core/impl/aggregation/protocol/BUILD index 0abb3d78bb..522de8d724 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/protocol/BUILD +++ b/tensorflow_federated/cc/core/impl/aggregation/protocol/BUILD @@ -41,6 +41,16 @@ proto_library( deps = ["//tensorflow_federated/cc/core/impl/aggregation/core:tensor_proto"], ) +proto_library( + name = "checkpoint_aggregator_proto", + srcs = ["checkpoint_aggregator.proto"], +) + +cc_proto_library( + name = "checkpoint_aggregator_cc_proto", + deps = [":checkpoint_aggregator_proto"], +) + cc_proto_library( name = "configuration_cc_proto", visibility = ["//visibility:public"], @@ -99,6 +109,7 @@ cc_library( hdrs = ["checkpoint_aggregator.h"], visibility = ["//visibility:public"], deps = [ + ":checkpoint_aggregator_cc_proto", ":checkpoint_builder", ":checkpoint_parser", ":config_converter", diff --git a/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.cc b/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.cc index 74c0fb2141..a3ffa52bc6 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.cc +++ b/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.cc @@ -36,6 +36,7 @@ #include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_aggregator_factory.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_aggregator_registry.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_spec.h" +#include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.pb.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_builder.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_parser.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/config_converter.h" @@ -67,13 +68,50 @@ absl::Status CheckpointAggregator::ValidateConfig( absl::StatusOr> CheckpointAggregator::Create(const Configuration& configuration) { + return CreateInternal(configuration, nullptr); +} + +absl::StatusOr> +CheckpointAggregator::Create(const std::vector* intrinsics) { + return CreateInternal(intrinsics, nullptr); +} + +absl::StatusOr> +CheckpointAggregator::Deserialize(const Configuration& configuration, + std::string serialized_state) { + CheckpointAggregatorState aggregator_state; + if (!aggregator_state.ParseFromString(serialized_state)) { + return absl::InvalidArgumentError("Failed to parse serialized state."); + } + return CreateInternal(configuration, &aggregator_state); +} + +absl::StatusOr> +CheckpointAggregator::Deserialize(const std::vector* intrinsics, + std::string serialized_state) { + CheckpointAggregatorState aggregator_state; + if (!aggregator_state.ParseFromString(serialized_state)) { + return absl::InvalidArgumentError("Failed to parse serialized state."); + } + return CreateInternal(intrinsics, &aggregator_state); +} + +absl::StatusOr> +CheckpointAggregator::CreateInternal( + const Configuration& configuration, + const CheckpointAggregatorState* aggregator_state) { TFF_ASSIGN_OR_RETURN(std::vector intrinsics, ParseFromConfig(configuration)); std::vector> aggregators; - for (const Intrinsic& intrinsic : intrinsics) { + for (int i = 0; i < intrinsics.size(); ++i) { + const Intrinsic& intrinsic = intrinsics[i]; + const std::string* serialized_aggregator = nullptr; + if (aggregator_state != nullptr) { + serialized_aggregator = &aggregator_state->aggregators(i); + } TFF_ASSIGN_OR_RETURN(std::unique_ptr aggregator, - CreateAggregator(intrinsic)); + CreateAggregator(intrinsic, serialized_aggregator)); aggregators.push_back(std::move(aggregator)); } @@ -82,11 +120,18 @@ CheckpointAggregator::Create(const Configuration& configuration) { } absl::StatusOr> -CheckpointAggregator::Create(const std::vector* intrinsics) { +CheckpointAggregator::CreateInternal( + const std::vector* intrinsics, + const CheckpointAggregatorState* aggregator_state) { std::vector> aggregators; - for (const Intrinsic& intrinsic : *intrinsics) { + for (int i = 0; i < intrinsics->size(); ++i) { + const Intrinsic& intrinsic = (*intrinsics)[i]; + const std::string* serialized_aggregator = nullptr; + if (aggregator_state != nullptr) { + serialized_aggregator = &aggregator_state->aggregators(i); + } TFF_ASSIGN_OR_RETURN(std::unique_ptr aggregator, - CreateAggregator(intrinsic)); + CreateAggregator(intrinsic, serialized_aggregator)); aggregators.push_back(std::move(aggregator)); } @@ -201,14 +246,33 @@ absl::Status CheckpointAggregator::Report( void CheckpointAggregator::Abort() { aggregation_finished_ = true; } +absl::StatusOr CheckpointAggregator::Serialize() && { + absl::MutexLock lock(&aggregation_mu_); + if (aggregation_finished_) { + return absl::AbortedError("Aggregation has already been finished."); + } + CheckpointAggregatorState state; + google::protobuf::RepeatedPtrField* aggregators_proto = + state.mutable_aggregators(); + aggregators_proto->Reserve(aggregators_.size()); + for (const auto& aggregator : aggregators_) { + aggregators_proto->Add(std::move(*aggregator).Serialize().value()); + } + return state.SerializeAsString(); +} + absl::StatusOr> -CheckpointAggregator::CreateAggregator(const Intrinsic& intrinsic) { +CheckpointAggregator::CreateAggregator( + const Intrinsic& intrinsic, const std::string* serialized_aggregator) { // Resolve the intrinsic_uri to the registered TensorAggregatorFactory. TFF_ASSIGN_OR_RETURN(const TensorAggregatorFactory* factory, GetAggregatorFactory(intrinsic.uri)); // Use the factory to create the TensorAggregator instance. - return factory->Create(intrinsic); + if (serialized_aggregator == nullptr) { + return factory->Create(intrinsic); + } + return factory->Deserialize(intrinsic, *serialized_aggregator); } std::vector> diff --git a/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.h b/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.h index 3df642767e..cb3c042a5b 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.h +++ b/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "absl/base/attributes.h" @@ -31,6 +32,7 @@ #include "absl/synchronization/mutex.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/intrinsic.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_aggregator.h" +#include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.pb.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_builder.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_parser.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/configuration.pb.h" @@ -62,6 +64,20 @@ class CheckpointAggregator { static absl::StatusOr> Create( const std::vector* intrinsics ABSL_ATTRIBUTE_LIFETIME_BOUND); + // Creates an instance of CheckpointAggregator based on the given + // configuration and serialized state. + static absl::StatusOr> Deserialize( + const Configuration& configuration, std::string serialized_state); + + // Creates an instance of CheckpointAggregator based on the given intrinsics + // and serialized state. + // The `intrinsics` are expected to be created using `ParseFromConfig` which + // validates the configuration. CheckpointAggregator does not take any + // ownership, and `intrinsics` must outlive it. + static absl::StatusOr> Deserialize( + const std::vector* intrinsics ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::string serialized_state); + // Accumulates a checkpoint via nested tensor aggregators. The tensors are // provided by the CheckpointParser instance. absl::Status Accumulate(CheckpointParser& checkpoint_parser); @@ -75,6 +91,8 @@ class CheckpointAggregator { // Signal that the aggregation must be aborted and the report can't be // produced. void Abort(); + // Serialize the internal state of the checkpoint aggregator as a string. + absl::StatusOr Serialize() &&; private: CheckpointAggregator( @@ -85,9 +103,18 @@ class CheckpointAggregator { std::vector intrinsics, std::vector> aggregators); - // Creates an aggregation intrinsic based on the intrinsic configuration. + // Creates an aggregation intrinsic based on the intrinsic configuration and + // optional serialized state. static absl::StatusOr> CreateAggregator( - const Intrinsic& intrinsic); + const Intrinsic& intrinsic, const std::string* serialized_aggregator); + + static absl::StatusOr> CreateInternal( + const Configuration& configuration, + const CheckpointAggregatorState* aggregator_state); + + static absl::StatusOr> CreateInternal( + const std::vector* intrinsics, + const CheckpointAggregatorState* aggregator_state); // Used by the implementation of Merge. std::vector> TakeAggregators() &&; diff --git a/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.proto b/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.proto new file mode 100644 index 0000000000..fe0ea2e2f4 --- /dev/null +++ b/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; + +package tensorflow_federated.aggregation; + +// Internal state representation of a CheckpointAggregator. +message CheckpointAggregatorState { + repeated bytes aggregators = 1; +} diff --git a/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator_test.cc b/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator_test.cc index 50bda1ae38..fac2cee94a 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator_test.cc +++ b/tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator_test.cc @@ -60,6 +60,9 @@ using ::testing::ByMove; using ::testing::Invoke; using ::testing::Return; using ::testing::StrEq; +using testing::TestWithParam; + +using CheckpointAggregatorTest = TestWithParam; Configuration default_configuration() { // One "federated_sum" intrinsic with a single scalar int32 tensor. @@ -401,19 +404,26 @@ TEST(CheckpointAggregatorTest, CreateMismatchingInputAndOutputShape) { StatusIs(INVALID_ARGUMENT)); } -TEST(CheckpointAggregatorTest, CreateFromIntrinsicsAccumulateSuccess) { +TEST_P(CheckpointAggregatorTest, CreateFromIntrinsicsAccumulateSuccess) { std::vector intrinsics; intrinsics.push_back({"federated_sum", {TensorSpec("foo", DT_INT32, {})}, {TensorSpec("foo_out", DT_INT32, {})}, {}, {}}); - auto aggregator = CheckpointAggregator::Create(&intrinsics); + auto aggregator = CheckpointAggregator::Create(&intrinsics).value(); + + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize().value(); + aggregator = + CheckpointAggregator::Deserialize(&intrinsics, serialized_state) + .value(); + } MockCheckpointParser parser; EXPECT_CALL(parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] { return Tensor::Create(DT_INT32, {}, CreateTestData({2})); })); - EXPECT_OK((*aggregator)->Accumulate(parser)); + TFF_EXPECT_OK(aggregator->Accumulate(parser)); } TEST(CheckpointAggregatorTest, AccumulateMissingTensor) { @@ -467,16 +477,23 @@ TEST(CheckpointAggregatorTest, AccumulateAfterAbort) { EXPECT_THAT(aggregator->Accumulate(parser), StatusIs(ABORTED)); } -TEST(CheckpointAggregatorTest, ReportZeroInputs) { +TEST_P(CheckpointAggregatorTest, ReportZeroInputs) { auto aggregator = CreateWithDefaultConfig(); + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize().value(); + aggregator = CheckpointAggregator::Deserialize(default_configuration(), + serialized_state) + .value(); + } + MockCheckpointBuilder builder; EXPECT_CALL(builder, Add(StrEq("foo_out"), IsTensor({}, {0}))) .WillOnce(Return(absl::OkStatus())); EXPECT_OK(aggregator->Report(builder)); } -TEST(CheckpointAggregatorTest, ReportOneInput) { +TEST_P(CheckpointAggregatorTest, ReportOneInput) { auto aggregator = CreateWithDefaultConfig(); MockCheckpointParser parser; EXPECT_CALL(parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] { @@ -484,13 +501,20 @@ TEST(CheckpointAggregatorTest, ReportOneInput) { })); EXPECT_OK(aggregator->Accumulate(parser)); + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize().value(); + aggregator = CheckpointAggregator::Deserialize(default_configuration(), + serialized_state) + .value(); + } + MockCheckpointBuilder builder; EXPECT_CALL(builder, Add(StrEq("foo_out"), IsTensor({}, {2}))) .WillOnce(Return(absl::OkStatus())); EXPECT_OK(aggregator->Report(builder)); } -TEST(CheckpointAggregatorTest, ReportTwoInputs) { +TEST_P(CheckpointAggregatorTest, ReportTwoInputs) { auto aggregator = CreateWithDefaultConfig(); MockCheckpointParser parser; EXPECT_CALL(parser, GetTensor(StrEq("foo"))) @@ -501,13 +525,20 @@ TEST(CheckpointAggregatorTest, ReportTwoInputs) { EXPECT_OK(aggregator->Accumulate(parser)); EXPECT_OK(aggregator->Accumulate(parser)); + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize().value(); + aggregator = CheckpointAggregator::Deserialize(default_configuration(), + serialized_state) + .value(); + } + MockCheckpointBuilder builder; EXPECT_CALL(builder, Add(StrEq("foo_out"), IsTensor({}, {5}))) .WillOnce(Return(absl::OkStatus())); EXPECT_OK(aggregator->Report(builder)); } -TEST(CheckpointAggregatorTest, ReportMultipleTensors) { +TEST_P(CheckpointAggregatorTest, ReportMultipleTensors) { Configuration config_message = PARSE_TEXT_PROTO(R"pb( intrinsic_configs { intrinsic_uri: "federated_sum" @@ -550,6 +581,14 @@ TEST(CheckpointAggregatorTest, ReportMultipleTensors) { CreateTestData({1.f, 2.f, 3.f, 4.f})); })); EXPECT_OK(aggregator->Accumulate(parser)); + + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize().value(); + aggregator = + CheckpointAggregator::Deserialize(config_message, serialized_state) + .value(); + } + MockCheckpointBuilder builder; EXPECT_CALL(builder, Add(StrEq("foo_out"), IsTensor({3}, {1, 2, 3}))) .WillOnce(Return(absl::OkStatus())); @@ -631,7 +670,7 @@ TEST(CheckpointAggregatorTest, ReportWithFailedCanReportPrecondition) { EXPECT_THAT(aggregator->Report(builder), StatusIs(FAILED_PRECONDITION)); } -TEST(CheckpointAggregatorTest, ReportFedSqlZeroInputs) { +TEST_P(CheckpointAggregatorTest, ReportFedSqlZeroInputs) { // One intrinsic: // fedsql_group_by with two grouping keys key1 and key2, only the first one // of which should be output, and two inner GoogleSQL:sum intrinsics bar @@ -697,6 +736,13 @@ TEST(CheckpointAggregatorTest, ReportFedSqlZeroInputs) { )pb"); auto aggregator = Create(config_message); + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize().value(); + aggregator = + CheckpointAggregator::Deserialize(config_message, serialized_state) + .value(); + } + MockCheckpointBuilder builder; // Verify that empty tensors are added to the result checkpoint. EXPECT_CALL(builder, Add(StrEq("key1_out"), IsTensor({0}, {}))) @@ -708,7 +754,7 @@ TEST(CheckpointAggregatorTest, ReportFedSqlZeroInputs) { EXPECT_OK(aggregator->Report(builder)); } -TEST(CheckpointAggregatorTest, ReportFedSqlsOneInput) { +TEST_P(CheckpointAggregatorTest, ReportFedSqlsOneInput) { auto aggregator = CreateWithDefaultFedSqlConfig(); MockCheckpointParser parser; @@ -720,6 +766,13 @@ TEST(CheckpointAggregatorTest, ReportFedSqlsOneInput) { })); EXPECT_OK(aggregator->Accumulate(parser)); + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize().value(); + aggregator = CheckpointAggregator::Deserialize( + default_fedsql_configuration(), serialized_state) + .value(); + } + MockCheckpointBuilder builder; EXPECT_CALL(builder, Add(StrEq("key1_out"), IsTensor({2}, {1.f, 2.f}))) .WillOnce(Return(absl::OkStatus())); @@ -728,7 +781,7 @@ TEST(CheckpointAggregatorTest, ReportFedSqlsOneInput) { EXPECT_OK(aggregator->Report(builder)); } -TEST(CheckpointAggregatorTest, ReportFedSqlsTwoInputs) { +TEST_P(CheckpointAggregatorTest, ReportFedSqlsTwoInputs) { auto aggregator = CreateWithDefaultFedSqlConfig(); MockCheckpointParser parser1; @@ -740,6 +793,13 @@ TEST(CheckpointAggregatorTest, ReportFedSqlsTwoInputs) { })); EXPECT_OK(aggregator->Accumulate(parser1)); + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize().value(); + aggregator = CheckpointAggregator::Deserialize( + default_fedsql_configuration(), serialized_state) + .value(); + } + MockCheckpointParser parser2; EXPECT_CALL(parser2, GetTensor(StrEq("key1"))).WillOnce(Invoke([] { return Tensor::Create(DT_FLOAT, {2}, CreateTestData({2.f, 4.f})); @@ -759,7 +819,7 @@ TEST(CheckpointAggregatorTest, ReportFedSqlsTwoInputs) { EXPECT_OK(aggregator->Report(builder)); } -TEST(CheckpointAggregatorTest, ReportFedSqlsEmptyInput) { +TEST_P(CheckpointAggregatorTest, ReportFedSqlsEmptyInput) { auto aggregator = CreateWithDefaultFedSqlConfig(); MockCheckpointParser parser; @@ -771,6 +831,13 @@ TEST(CheckpointAggregatorTest, ReportFedSqlsEmptyInput) { })); EXPECT_OK(aggregator->Accumulate(parser)); + if (GetParam()) { + auto serialized_state = std::move(*aggregator).Serialize().value(); + aggregator = CheckpointAggregator::Deserialize( + default_fedsql_configuration(), serialized_state) + .value(); + } + MockCheckpointBuilder builder; EXPECT_CALL(builder, Add(StrEq("key1_out"), IsTensor({0}, {}))) .WillOnce(Return(absl::OkStatus())); @@ -779,7 +846,7 @@ TEST(CheckpointAggregatorTest, ReportFedSqlsEmptyInput) { EXPECT_OK(aggregator->Report(builder)); } -TEST(CheckpointAggregatorTest, MergeSuccess) { +TEST_P(CheckpointAggregatorTest, MergeSuccess) { auto aggregator1 = CreateWithDefaultConfig(); MockCheckpointParser parser1; EXPECT_CALL(parser1, GetTensor(StrEq("foo"))).WillOnce(Invoke([] { @@ -794,8 +861,26 @@ TEST(CheckpointAggregatorTest, MergeSuccess) { })); EXPECT_OK(aggregator2->Accumulate(parser2)); + if (GetParam()) { + auto serialized_state1 = std::move(*aggregator1).Serialize().value(); + aggregator1 = CheckpointAggregator::Deserialize(default_configuration(), + serialized_state1) + .value(); + auto serialized_state2 = std::move(*aggregator2).Serialize().value(); + aggregator2 = CheckpointAggregator::Deserialize(default_configuration(), + serialized_state2) + .value(); + } + EXPECT_OK(aggregator1->MergeWith(std::move(*aggregator2))); + if (GetParam()) { + auto serialized_state1 = std::move(*aggregator1).Serialize().value(); + aggregator1 = CheckpointAggregator::Deserialize(default_configuration(), + serialized_state1) + .value(); + } + MockCheckpointBuilder builder; EXPECT_CALL(builder, Add(StrEq("foo_out"), IsTensor({}, {12}))) .WillOnce(Return(absl::OkStatus())); @@ -930,6 +1015,35 @@ TEST(CheckpointAggregatorTest, ConcurrentAccumulationAbortWhileQueued) { scheduler->WaitUntilIdle(); } +TEST(CheckpointAggregatorTest, SerializeAfterReport) { + auto aggregator = CreateWithDefaultConfig(); + + MockCheckpointBuilder builder; + EXPECT_CALL(builder, Add(StrEq("foo_out"), IsTensor({}, {0}))) + .WillOnce(Return(absl::OkStatus())); + TFF_EXPECT_OK(aggregator->Report(builder)); + EXPECT_THAT(std::move(*aggregator).Serialize(), StatusIs(ABORTED)); +} + +TEST(CheckpointAggregatorTest, SerializeAfterAbort) { + auto aggregator = CreateWithDefaultConfig(); + aggregator->Abort(); + EXPECT_THAT(std::move(*aggregator).Serialize(), StatusIs(ABORTED)); +} + +TEST(CheckpointAggregatorTest, DeserializeInvalidState) { + std::string serialized_state = "invalid"; + EXPECT_THAT(CheckpointAggregator::Deserialize(default_configuration(), + serialized_state), + StatusIs(INVALID_ARGUMENT)); +} + +INSTANTIATE_TEST_SUITE_P( + CheckpointAggregatorTestInstantiation, CheckpointAggregatorTest, + testing::ValuesIn({false, true}), + [](const testing::TestParamInfo& + info) { return info.param ? "SerializeDeserialize" : "None"; }); + } // namespace } // namespace aggregation } // namespace tensorflow_federated