diff --git a/tensorflow_federated/python/learning/programs/training_program_logic.py b/tensorflow_federated/python/learning/programs/training_program_logic.py
index bd326ce1e6..48b21b6742 100644
--- a/tensorflow_federated/python/learning/programs/training_program_logic.py
+++ b/tensorflow_federated/python/learning/programs/training_program_logic.py
@@ -24,7 +24,7 @@
 import asyncio
 from collections.abc import Coroutine
 import datetime
-from typing import Optional, Union
+from typing import NamedTuple, Optional, Union

 from absl import logging

@@ -38,6 +38,14 @@
 from tensorflow_federated.python.program import value_reference


+class ProgramState(NamedTuple):
+  """A structure representing the program state."""
+
+  state: composers.LearningAlgorithmState
+  round_number: int
+  next_evaluation_timestamp_seconds: Optional[int]
+
+
 class TaskManager:
   """A manager for inflight tasks.

@@ -160,10 +168,14 @@ async def train_model(
   if train_state is None:
     raise ValueError('The initial train state is None.')
   program_state, version = await program_state_manager.load_latest(
-      (train_state, 0)
+      ProgramState(train_state, 0, 0)
   )
   if program_state is not None:
-    train_state, start_round = program_state
+    train_state = program_state.state
+    start_round = program_state.round_number
+    next_evaluation_timestamp_seconds = (
+        program_state.next_evaluation_timestamp_seconds
+    )
     logging.info('Found previous program state version %d', version)
     if start_round < train_total_rounds:
       logging.info(
@@ -192,8 +204,12 @@ async def train_model(
     # Ensure the initial state (round 0) is saved before any training occurs.
     # The program manager `keep_first=True` parameterization will enable users
     # to start future experiments from the same initialization.
+    next_evaluation_timestamp_seconds = None
     await program_state_manager.save(
-        (train_state, start_round), version=start_round
+        ProgramState(
+            train_state, start_round, next_evaluation_timestamp_seconds
+        ),
+        version=start_round,
     )

   train_state_type, _ = train_process.next.type_signature.result  # pytype: disable=attribute-error
@@ -201,7 +217,11 @@

   # Track a future time after which an evaluation should be started. This will
   # be `evaluation_periodicity` after the most recent evaluation time.
-  next_evaluation_time = None
+  next_evaluation_time = (
+      datetime.datetime.fromtimestamp(next_evaluation_timestamp_seconds)
+      if next_evaluation_timestamp_seconds
+      else None
+  )

   def should_evaluate_round(
       round_num: int, train_round_finished_time: datetime.datetime
@@ -251,8 +271,27 @@ def should_evaluate_round(
     )
     train_state = train_result.state
     train_metrics = train_result.metrics
+
+    train_round_finished_time = datetime.datetime.now()
+    if evaluation_manager is not None and should_evaluate_round(
+        round_num, train_round_finished_time
+    ):
+      model_weights = train_process.get_model_weights(train_state)
+      await evaluation_manager.start_evaluation(
+          round_num, int(train_round_finished_time.timestamp()), model_weights
+      )
+      logging.info('Added evaluation for training round %d', round_num)
+
+    next_evaluation_timestamp_seconds = (
+        int(next_evaluation_time.timestamp()) if next_evaluation_time else None
+    )
     task_manager.add_task(
-        program_state_manager.save((train_state, round_num), version=round_num)
+        program_state_manager.save(
+            ProgramState(
+                train_state, round_num, next_evaluation_timestamp_seconds
+            ),
+            version=round_num,
+        )
     )
     if train_metrics_manager is not None:
       try:
@@ -276,15 +315,6 @@ def should_evaluate_round(
               key=round_num,
          )
      )
-    train_round_finished_time = datetime.datetime.now()
-    if evaluation_manager is not None and should_evaluate_round(
-        round_num, train_round_finished_time
-    ):
-      model_weights = train_process.get_model_weights(train_state)
-      await evaluation_manager.start_evaluation(
-          round_num, int(train_round_finished_time.timestamp()), model_weights
-      )
-      logging.info('Added evaluation for training round %d', round_num)

     task_manager.add_task(
         model_output_manager.release(
diff --git a/tensorflow_federated/python/learning/programs/training_program_logic_test.py b/tensorflow_federated/python/learning/programs/training_program_logic_test.py
index 70d5fef9bd..0210c90195 100644
--- a/tensorflow_federated/python/learning/programs/training_program_logic_test.py
+++ b/tensorflow_federated/python/learning/programs/training_program_logic_test.py
@@ -36,6 +36,7 @@
 from tensorflow_federated.python.program import release_manager

 # Convenience aliases.
+ProgramState = training_program_logic.ProgramState
 TensorType = computation_types.TensorType


@@ -225,16 +226,21 @@ async def return_round_num() -> int:
     )
     self.assertEqual(
         mock_program_state_manager.load_latest.call_args_list,
-        [mock.call((any_algorithm_state, 0))],
-    )
+        [mock.call(ProgramState(any_algorithm_state, 0, 0))],
+    )
+    expected_state_manager_call_list = []
+    # Expect saving the initial state (version 0) and training rounds 1
+    # through training_rounds.
+    for round_num in range(0, training_rounds + 1):
+      expected_state_manager_call_list.append(
+          mock.call(
+              ProgramState(any_algorithm_state, round_num, None),
+              version=round_num,
+          )
+      )
     self.assertEqual(
         mock_program_state_manager.save.call_args_list,
-        # Expect saving the initial state (version 0) and training rounds 1
-        # through training_rounds.
-        [
-            mock.call((any_algorithm_state, round_num), version=round_num)
-            for round_num in range(0, training_rounds + 1)
-        ],
+        expected_state_manager_call_list,
     )

     # Assert that training metrics were released every round.
@@ -312,7 +318,7 @@ async def return_round_num() -> None:
     )

     # Patch `datetime.now` so that each round looks like it takes 20
-    # milliseconds. With evaluation periodicity of 25 milliseconds and training
+    # seconds. With evaluation periodicity of 25 seconds and training
     # rounds finishing (relative) at [0, 20, 40, 60, 80], the test will expect
     # evaluations at round [1, 3, 5].
     with mock.patch(
@@ -320,7 +326,7 @@
     ) as mock_datetime:
       start_datetime = datetime.datetime(2022, 11, 17, 9, 0)
       mock_datetime.now.side_effect = [
-          start_datetime + datetime.timedelta(milliseconds=20 * i)
+          start_datetime + datetime.timedelta(seconds=20 * i)
          for i in range(training_rounds)
      ]
      await training_program_logic.train_model(
@@ -332,7 +338,7 @@
          model_output_manager=mock_model_output_manager,
          evaluation_manager=mock_evaluation_manager,
          train_metrics_manager=mock_train_metrics_manager,
-          evaluation_periodicity=datetime.timedelta(milliseconds=25),
+          evaluation_periodicity=datetime.timedelta(seconds=25),
      )

     # Assert that the program attempted to load a previous checkpoint and then
@@ -346,16 +352,32 @@
     )
     self.assertEqual(
         mock_program_state_manager.load_latest.call_args_list,
-        [mock.call((any_algorithm_state, 0))],
-    )
+        [mock.call(ProgramState(any_algorithm_state, 0, 0))],
+    )
+    # The next evaluation time of the first round is None. The last round will
+    # always be evaluated, and the next evaluation time won't be updated.
+    next_evaluation_timestamps = [None]
+    for relative_timestamp in [25, 25, 65, 65, 65]:
+      next_evaluation_timestamps.append(
+          relative_timestamp + int(start_datetime.timestamp())
+      )
+    # Expect saving the initial state (version 0) and training rounds 1
+    # through training_rounds.
+    expected_state_manager_call_list = []
+    for round_num in range(0, training_rounds + 1):
+      expected_state_manager_call_list.append(
+          mock.call(
+              ProgramState(
+                  any_algorithm_state,
+                  round_num,
+                  next_evaluation_timestamps[round_num],
+              ),
+              version=round_num,
+          )
+      )
     self.assertEqual(
         mock_program_state_manager.save.call_args_list,
-        # Expect saving the initial state (version 0) and training rounds 1
-        # through training_rounds.
-        [
-            mock.call((any_algorithm_state, round_num), version=round_num)
-            for round_num in range(0, training_rounds + 1)
-        ],
+        expected_state_manager_call_list,
     )

     # Assert that training metrics were released every round.
@@ -432,16 +454,21 @@ async def test_integration_runs_5_training_rounds_no_eval_manager(self):
     )
     self.assertEqual(
         mock_program_state_manager.load_latest.call_args_list,
-        [mock.call((any_algorithm_state, 0))],
-    )
+        [mock.call(ProgramState(any_algorithm_state, 0, 0))],
+    )
+    expected_state_manager_call_list = []
+    # Expect saving the initial state (version 0) and training rounds 1
+    # through training_rounds.
+    for round_num in range(0, training_rounds + 1):
+      expected_state_manager_call_list.append(
+          mock.call(
+              ProgramState(any_algorithm_state, round_num, None),
+              version=round_num,
+          )
+      )
     self.assertEqual(
         mock_program_state_manager.save.call_args_list,
-        # Expect saving the initial state (version 0) and training rounds 1
-        # through training_rounds.
-        [
-            mock.call((any_algorithm_state, round_num), version=round_num)
-            for round_num in range(0, training_rounds + 1)
-        ],
+        expected_state_manager_call_list,
     )

     # Assert that training metrics were released every round.
@@ -509,11 +536,11 @@ async def test_program_state_manager_work_with_initial_state(self):
     # given initial state and save it as version 0.
     self.assertEqual(
         mock_program_state_manager.load_latest.call_args_list,
-        [mock.call((initial_train_state, 0))],
+        [mock.call(ProgramState(initial_train_state, 0, 0))],
     )
     self.assertEqual(
         mock_program_state_manager.save.call_args_list[0],
-        mock.call((initial_train_state, 0), version=0),
+        mock.call(ProgramState(initial_train_state, 0, None), version=0),
     )

   @context_stack_test_utils.with_context(_create_test_context)
@@ -529,7 +556,7 @@ async def test_resumes_from_previous_version_10_runs_one_round(self):
         program_state_manager.ProgramStateManager, instance=True
     )
     mock_program_state_manager.load_latest.side_effect = [(
-        (training_state, training_rounds - 1),
+        ProgramState(training_state, training_rounds - 1, None),
         training_rounds - 1,
     )]

@@ -571,7 +598,7 @@ async def test_resumes_from_previous_version_10_runs_one_round(self):
     )
     self.assertEqual(
         mock_program_state_manager.load_latest.call_args_list,
-        [mock.call((any_algorithm_state, 0))],
+        [mock.call(ProgramState(any_algorithm_state, 0, 0))],
     )
     self.assertEqual(
         mock_program_state_manager.save.call_args_list,
@@ -579,7 +606,8 @@
         # 10th version)
         [
             mock.call(
-                (any_algorithm_state, training_rounds), version=training_rounds
+                ProgramState(any_algorithm_state, training_rounds, None),
+                version=training_rounds,
             )
         ],
         msg=mock_program_state_manager.save.call_args_list,
@@ -620,7 +648,7 @@ async def test_resumes_from_previous_runs_no_train_rounds(self):
         program_state_manager.ProgramStateManager, instance=True
     )
     mock_program_state_manager.load_latest.side_effect = [(
-        (training_state, training_rounds),
+        ProgramState(training_state, training_rounds, None),
         training_rounds,
     )]

@@ -660,7 +688,7 @@ async def test_resumes_from_previous_runs_no_train_rounds(self):
     )
     self.assertSequenceEqual(
         mock_program_state_manager.load_latest.call_args_list,
-        [mock.call((any_algorithm_state, 0))],
+        [mock.call(ProgramState(any_algorithm_state, 0, 0))],
     )

     mock_program_state_manager.save.assert_not_called()
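Note (editor's addition, not part of the patch): the sketch below illustrates how the new `ProgramState` NamedTuple carries `next_evaluation_timestamp_seconds` across a save/restore cycle, using the same truthiness convention that `train_model` applies when it rebuilds `next_evaluation_time`. The class here mirrors the fields added by the patch but elides the learning-algorithm state, and the helper `restore_next_evaluation_time` is a hypothetical name used for illustration only; it does not exist in the library.

import datetime
from typing import NamedTuple, Optional


class ProgramState(NamedTuple):
  """Mirror of the patch's ProgramState; `state` is reduced to `object` here."""

  state: object
  round_number: int
  next_evaluation_timestamp_seconds: Optional[int]


def restore_next_evaluation_time(
    program_state: ProgramState,
) -> Optional[datetime.datetime]:
  """Rebuilds the wall-clock evaluation deadline from the persisted integer."""
  # Both `None` and `0` mean "no evaluation deadline recorded yet", matching
  # the `if next_evaluation_timestamp_seconds ... else None` expression above.
  if program_state.next_evaluation_timestamp_seconds:
    return datetime.datetime.fromtimestamp(
        program_state.next_evaluation_timestamp_seconds
    )
  return None


# Round-trip: an int timestamp saved at round 3 becomes a datetime on resume.
resumed = ProgramState(
    state=object(),
    round_number=3,
    next_evaluation_timestamp_seconds=1668675625,
)
assert restore_next_evaluation_time(resumed) == datetime.datetime.fromtimestamp(
    1668675625
)
# A fresh run (no deadline recorded yet) restores to None.
assert restore_next_evaluation_time(
    resumed._replace(next_evaluation_timestamp_seconds=None)
) is None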