Save the next evaluation time in the program state.
This way, the next evaluation time can be restored when the program restarts.

PiperOrigin-RevId: 564454622
xiaoyux11 authored and tensorflow-copybara committed Sep 13, 2023
1 parent 4a28765 commit e75b511
Showing 2 changed files with 107 additions and 49 deletions.
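
The heart of the change: the checkpointed program state, previously a bare (train_state, round_number) tuple, becomes a ProgramState named tuple that additionally records the next evaluation time as integer seconds since the epoch. Below is a minimal sketch of the round-trip this enables, with illustrative stand-ins (the bytes field and the assert driver are not the TFF API):

import datetime
from typing import NamedTuple, Optional


# Mirrors the NamedTuple added in the diff; `bytes` stands in for the real
# composers.LearningAlgorithmState.
class ProgramState(NamedTuple):
  state: bytes
  round_number: int
  next_evaluation_timestamp_seconds: Optional[int]


def restore_next_evaluation_time(
    saved: ProgramState,
) -> Optional[datetime.datetime]:
  # The same conversion the diff performs on restart: the persisted int
  # becomes a datetime; a missing timestamp means nothing is scheduled yet.
  if saved.next_evaluation_timestamp_seconds:
    return datetime.datetime.fromtimestamp(
        saved.next_evaluation_timestamp_seconds
    )
  return None


# What train_model would save after round 3...
next_eval = datetime.datetime(2023, 9, 13, 12, 0)
saved = ProgramState(b'weights', 3, int(next_eval.timestamp()))
# ...is exactly what a restarted program reads back.
assert restore_next_evaluation_time(saved) == next_eval

Persisting an Optional[int] rather than a datetime keeps the checkpoint payload down to primitive, easily serialized values.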
First changed file:

@@ -24,7 +24,7 @@
 import asyncio
 from collections.abc import Coroutine
 import datetime
-from typing import Optional, Union
+from typing import NamedTuple, Optional, Union
 
 from absl import logging
 
@@ -38,6 +38,14 @@
 from tensorflow_federated.python.program import value_reference
 
 
+class ProgramState(NamedTuple):
+  """A structure representing the program state."""
+
+  state: composers.LearningAlgorithmState
+  round_number: int
+  next_evaluation_timestamp_seconds: Optional[int]
+
+
 class TaskManager:
   """A manager for inflight tasks.
@@ -160,10 +168,14 @@ async def train_model(
   if train_state is None:
     raise ValueError('The initial train state is None.')
   program_state, version = await program_state_manager.load_latest(
-      (train_state, 0)
+      ProgramState(train_state, 0, 0)
   )
   if program_state is not None:
-    train_state, start_round = program_state
+    train_state = program_state.state
+    start_round = program_state.round_number
+    next_evaluation_timestamp_seconds = (
+        program_state.next_evaluation_timestamp_seconds
+    )
     logging.info('Found previous program state version %d', version)
     if start_round < train_total_rounds:
       logging.info(
@@ -192,16 +204,24 @@ async def train_model(
     # Ensure the initial state (round 0) is saved before any training occurs.
     # The program manager `keep_first=True` parameterization will enable users
     # to start future experiments from the same initialization.
+    next_evaluation_timestamp_seconds = None
     await program_state_manager.save(
-        (train_state, start_round), version=start_round
+        ProgramState(
+            train_state, start_round, next_evaluation_timestamp_seconds
+        ),
+        version=start_round,
     )
 
   train_state_type, _ = train_process.next.type_signature.result  # pytype: disable=attribute-error
   train_data_iterator = train_data_source.iterator()
 
   # Track a future time after which an evaluation should be started. This will
   # be `evaluation_periodicity` after the most recent evaluation time.
-  next_evaluation_time = None
+  next_evaluation_time = (
+      datetime.datetime.fromtimestamp(next_evaluation_timestamp_seconds)
+      if next_evaluation_timestamp_seconds
+      else None
+  )
 
   def should_evaluate_round(
       round_num: int, train_round_finished_time: datetime.datetime
@@ -251,8 +271,27 @@ def should_evaluate_round(
     )
     train_state = train_result.state
     train_metrics = train_result.metrics
+
+    train_round_finished_time = datetime.datetime.now()
+    if evaluation_manager is not None and should_evaluate_round(
+        round_num, train_round_finished_time
+    ):
+      model_weights = train_process.get_model_weights(train_state)
+      await evaluation_manager.start_evaluation(
+          round_num, int(train_round_finished_time.timestamp()), model_weights
+      )
+      logging.info('Added evaluation for training round %d', round_num)
+
+    next_evaluation_timestamp_seconds = (
+        int(next_evaluation_time.timestamp()) if next_evaluation_time else None
+    )
     task_manager.add_task(
-        program_state_manager.save((train_state, round_num), version=round_num)
+        program_state_manager.save(
+            ProgramState(
+                train_state, round_num, next_evaluation_timestamp_seconds
+            ),
+            version=round_num,
+        )
     )
     if train_metrics_manager is not None:
       try:
@@ -276,15 +315,6 @@ def should_evaluate_round(
               key=round_num,
           )
       )
-    train_round_finished_time = datetime.datetime.now()
-    if evaluation_manager is not None and should_evaluate_round(
-        round_num, train_round_finished_time
-    ):
-      model_weights = train_process.get_model_weights(train_state)
-      await evaluation_manager.start_evaluation(
-          round_num, int(train_round_finished_time.timestamp()), model_weights
-      )
-      logging.info('Added evaluation for training round %d', round_num)
 
     task_manager.add_task(
         model_output_manager.release(
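
Note that the body of should_evaluate_round is not part of this diff. As a rough, assumed reconstruction of the scheduling rule the new field feeds (not copied from the source): evaluate when no evaluation time is scheduled yet, or when the round finished at or after the scheduled time, then advance the schedule by evaluation_periodicity from the finish time. Replaying the test clock below (rounds finishing 20 seconds apart, 25-second periodicity) reproduces the expected evaluations at rounds 1, 3, and 5:

import datetime
from typing import Optional

EVALUATION_PERIODICITY = datetime.timedelta(seconds=25)


def should_evaluate(
    finished: datetime.datetime,
    next_evaluation_time: Optional[datetime.datetime],
) -> bool:
  # Assumed rule: fire when nothing is scheduled yet or the schedule is due.
  return next_evaluation_time is None or next_evaluation_time <= finished


start = datetime.datetime(2022, 11, 17, 9, 0)
next_eval: Optional[datetime.datetime] = None
evaluated = []
for round_num in range(1, 6):
  finished = start + datetime.timedelta(seconds=20 * (round_num - 1))
  if should_evaluate(finished, next_eval):
    evaluated.append(round_num)
    next_eval = finished + EVALUATION_PERIODICITY

assert evaluated == [1, 3, 5]

One wrinkle the tests document that this sketch omits: the final round is always evaluated without advancing the schedule, which is why the expected saved timestamps below are [None, 25, 25, 65, 65, 65] seconds relative to the start time.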
Second changed file:

@@ -36,6 +36,7 @@
 from tensorflow_federated.python.program import release_manager
 
 # Convenience aliases.
+ProgramState = training_program_logic.ProgramState
 TensorType = computation_types.TensorType
 
 
@@ -225,16 +226,21 @@ async def return_round_num() -> int:
     )
     self.assertEqual(
         mock_program_state_manager.load_latest.call_args_list,
-        [mock.call((any_algorithm_state, 0))],
-    )
+        [mock.call(ProgramState(any_algorithm_state, 0, 0))],
+    )
+    expected_state_manager_call_list = []
+    # Expect saving the initial state (version 0) and training rounds 1
+    # through training_rounds.
+    for round_num in range(0, training_rounds + 1):
+      expected_state_manager_call_list.append(
+          mock.call(
+              ProgramState(any_algorithm_state, round_num, None),
+              version=round_num,
+          )
+      )
     self.assertEqual(
         mock_program_state_manager.save.call_args_list,
-        # Expect saving the initial state (version 0) and training rounds 1
-        # through training_rounds.
-        [
-            mock.call((any_algorithm_state, round_num), version=round_num)
-            for round_num in range(0, training_rounds + 1)
-        ],
+        expected_state_manager_call_list,
     )
 
     # Assert that training metrics were released every round.
@@ -312,15 +318,15 @@ async def return_round_num() -> None:
     )
 
     # Patch `datetime.now` so that each round looks like it takes 20
-    # milliseconds. With evaluation periodicity of 25 milliseconds and training
+    # seconds. With evaluation periodicity of 25 seconds and training
     # rounds finishing (relative) at [0, 20, 40, 60, 80], the test will expect
     # evaluations at round [1, 3, 5].
     with mock.patch(
         'datetime.datetime', wraps=datetime.datetime
     ) as mock_datetime:
       start_datetime = datetime.datetime(2022, 11, 17, 9, 0)
       mock_datetime.now.side_effect = [
-          start_datetime + datetime.timedelta(milliseconds=20 * i)
+          start_datetime + datetime.timedelta(seconds=20 * i)
          for i in range(training_rounds)
       ]
       await training_program_logic.train_model(
@@ -332,7 +338,7 @@ async def return_round_num() -> None:
           model_output_manager=mock_model_output_manager,
           evaluation_manager=mock_evaluation_manager,
           train_metrics_manager=mock_train_metrics_manager,
-          evaluation_periodicity=datetime.timedelta(milliseconds=25),
+          evaluation_periodicity=datetime.timedelta(seconds=25),
       )
 
     # Assert that the program attempted to load a previous checkpoint and then
@@ -346,16 +352,32 @@ async def return_round_num() -> None:
     )
     self.assertEqual(
         mock_program_state_manager.load_latest.call_args_list,
-        [mock.call((any_algorithm_state, 0))],
-    )
+        [mock.call(ProgramState(any_algorithm_state, 0, 0))],
+    )
+    # The next evaluation time of the first round is None. The last round will
+    # always be evaluated, and the next evaluation time won't be updated.
+    next_evaluation_timestamps = [None]
+    for relative_timestamp in [25, 25, 65, 65, 65]:
+      next_evaluation_timestamps.append(
+          relative_timestamp + int(start_datetime.timestamp())
+      )
+    # Expect saving the initial state (version 0) and training rounds 1
+    # through training_rounds.
+    expected_state_manager_call_list = []
+    for round_num in range(0, training_rounds + 1):
+      expected_state_manager_call_list.append(
+          mock.call(
+              ProgramState(
+                  any_algorithm_state,
+                  round_num,
+                  next_evaluation_timestamps[round_num],
+              ),
+              version=round_num,
+          )
+      )
     self.assertEqual(
         mock_program_state_manager.save.call_args_list,
-        # Expect saving the initial state (version 0) and training rounds 1
-        # through training_rounds.
-        [
-            mock.call((any_algorithm_state, round_num), version=round_num)
-            for round_num in range(0, training_rounds + 1)
-        ],
+        expected_state_manager_call_list,
     )
 
     # Assert that training metrics were released every round.
@@ -432,16 +454,21 @@ async def test_integration_runs_5_training_rounds_no_eval_manager(self):
     )
     self.assertEqual(
         mock_program_state_manager.load_latest.call_args_list,
-        [mock.call((any_algorithm_state, 0))],
-    )
+        [mock.call(ProgramState(any_algorithm_state, 0, 0))],
+    )
+    expected_state_manager_call_list = []
+    # Expect saving the initial state (version 0) and training rounds 1
+    # through training_rounds.
+    for round_num in range(0, training_rounds + 1):
+      expected_state_manager_call_list.append(
+          mock.call(
+              ProgramState(any_algorithm_state, round_num, None),
+              version=round_num,
+          )
+      )
     self.assertEqual(
         mock_program_state_manager.save.call_args_list,
-        # Expect saving the initial state (version 0) and training rounds 1
-        # through training_rounds.
-        [
-            mock.call((any_algorithm_state, round_num), version=round_num)
-            for round_num in range(0, training_rounds + 1)
-        ],
+        expected_state_manager_call_list,
     )
 
     # Assert that training metrics were released every round.
@@ -509,11 +536,11 @@ async def test_program_state_manager_work_with_initial_state(self):
     # given initial state and save it as version 0.
     self.assertEqual(
         mock_program_state_manager.load_latest.call_args_list,
-        [mock.call((initial_train_state, 0))],
+        [mock.call(ProgramState(initial_train_state, 0, 0))],
     )
     self.assertEqual(
         mock_program_state_manager.save.call_args_list[0],
-        mock.call((initial_train_state, 0), version=0),
+        mock.call(ProgramState(initial_train_state, 0, None), version=0),
     )
 
   @context_stack_test_utils.with_context(_create_test_context)
@@ -529,7 +556,7 @@ async def test_resumes_from_previous_version_10_runs_one_round(self):
         program_state_manager.ProgramStateManager, instance=True
     )
     mock_program_state_manager.load_latest.side_effect = [(
-        (training_state, training_rounds - 1),
+        ProgramState(training_state, training_rounds - 1, None),
         training_rounds - 1,
     )]

@@ -571,15 +598,16 @@ async def test_resumes_from_previous_version_10_runs_one_round(self):
     )
     self.assertEqual(
         mock_program_state_manager.load_latest.call_args_list,
-        [mock.call((any_algorithm_state, 0))],
+        [mock.call(ProgramState(any_algorithm_state, 0, 0))],
     )
     self.assertEqual(
         mock_program_state_manager.save.call_args_list,
         # Expect saving the 11th version (running one round after loading the
         # 10th version)
         [
             mock.call(
-                (any_algorithm_state, training_rounds), version=training_rounds
+                ProgramState(any_algorithm_state, training_rounds, None),
+                version=training_rounds,
             )
         ],
         msg=mock_program_state_manager.save.call_args_list,
@@ -620,7 +648,7 @@ async def test_resumes_from_previous_runs_no_train_rounds(self):
         program_state_manager.ProgramStateManager, instance=True
     )
     mock_program_state_manager.load_latest.side_effect = [(
-        (training_state, training_rounds),
+        ProgramState(training_state, training_rounds, None),
         training_rounds,
     )]

@@ -660,7 +688,7 @@ async def test_resumes_from_previous_runs_no_train_rounds(self):
     )
     self.assertSequenceEqual(
         mock_program_state_manager.load_latest.call_args_list,
-        [mock.call((any_algorithm_state, 0))],
+        [mock.call(ProgramState(any_algorithm_state, 0, 0))],
     )
     mock_program_state_manager.save.assert_not_called()
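
As a closing aside, the deterministic clock the tests rely on can be reproduced in isolation. This is the same mock.patch(..., wraps=datetime.datetime) pattern used above: wrapping keeps the constructor and arithmetic real while now() is fed scripted values.

import datetime
from unittest import mock

with mock.patch('datetime.datetime', wraps=datetime.datetime) as mock_datetime:
  start = datetime.datetime(2022, 11, 17, 9, 0)  # real datetime, thanks to wraps
  mock_datetime.now.side_effect = [
      start + datetime.timedelta(seconds=20 * i) for i in range(3)
  ]
  print(datetime.datetime.now())  # 2022-11-17 09:00:00
  print(datetime.datetime.now())  # 2022-11-17 09:00:20
  print(datetime.datetime.now())  # 2022-11-17 09:00:40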

