From 8b2c2d26ba359beaca03b35eaf5c272ff221c9cd Mon Sep 17 00:00:00 2001 From: fw Date: Tue, 27 Aug 2024 01:30:42 +0000 Subject: [PATCH 1/2] fix --- realhf/system/master_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/realhf/system/master_worker.py b/realhf/system/master_worker.py index c7a5aa16..dcd89bd0 100755 --- a/realhf/system/master_worker.py +++ b/realhf/system/master_worker.py @@ -551,7 +551,7 @@ async def model_rpc_request_func( assert sample.bs % dp_size == 0 min_n_seqs_per_dp = sample.bs // dp_size else: - min_n_seqs_per_dp = rpc.n_mbs + min_n_seqs_per_dp = rpc.n_mbs if rpc.n_mbs is not None else 1 split_spec = sample.get_split_spec(dp_size, min_size=min_n_seqs_per_dp) partitions = split_spec.partitions target_mapping = {i: list(range(v[0], v[1])) for i, v in enumerate(partitions)} From 6adc95cb0a2611c268d649337fc7fba9b97aaefe Mon Sep 17 00:00:00 2001 From: fw Date: Tue, 27 Aug 2024 03:45:43 +0000 Subject: [PATCH 2/2] add save_eval_timeout at the last step --- realhf/api/core/system_api.py | 7 +++++++ realhf/system/master_worker.py | 6 ++++++ 2 files changed, 13 insertions(+) diff --git a/realhf/api/core/system_api.py b/realhf/api/core/system_api.py index 84002c4d..b7c10247 100755 --- a/realhf/api/core/system_api.py +++ b/realhf/api/core/system_api.py @@ -184,6 +184,11 @@ class ExperimentSaveEvalControl: :param benchmark_steps: Terminate the training after this number of steps. Used by system benchmark only. Please leave it to None for normal training. :type benchmark_steps: Optional[int] + :param save_eval_timeout: Timeout in seconds for saving and evaluation. + Will be used for the last step of the experiment. The master worker will sleep + for `save_eval_timeout` seconds to wait for all save or evaluation requests to finish. + Defaults to 120 seconds. 
+ :type save_eval_timeout: int """ total_train_epochs: int = 1 @@ -197,6 +202,8 @@ class ExperimentSaveEvalControl: eval_freq_secs: Optional[int] = None # benchmark benchmark_steps: Optional[int] = None + # Graceful exit + save_eval_timeout: int = 120 @dataclasses.dataclass diff --git a/realhf/system/master_worker.py b/realhf/system/master_worker.py index dcd89bd0..16c614a4 100755 --- a/realhf/system/master_worker.py +++ b/realhf/system/master_worker.py @@ -1344,6 +1344,12 @@ def _poll(self): if is_new_epoch: if self._epoch > self.__total_train_epochs: + if should_eval or should_save: + logger.info( + f"Waiting for all save/eval requests at the last step" + f" for {self.config.exp_ctrl.save_eval_timeout} secs..." + ) + time.sleep(self.config.exp_ctrl.save_eval_timeout) self.experiment_complete_exit(f"Training completes! Yeah!!!") total_time_consumption = time.perf_counter() - self._train_start_time