From c68c8d681e9263224e35ca0cc453446006dd62a8 Mon Sep 17 00:00:00 2001 From: Wei Fu <36355462+garrett4wade@users.noreply.github.com> Date: Tue, 3 Sep 2024 15:47:57 +0800 Subject: [PATCH] [Patch] Add a `clear_cache_freq` option in the commandline. (#63) * . * . --- realhf/experiments/common/common.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/realhf/experiments/common/common.py b/realhf/experiments/common/common.py index 2204b1d7..27eb8a0b 100644 --- a/realhf/experiments/common/common.py +++ b/realhf/experiments/common/common.py @@ -150,6 +150,9 @@ class CommonExperimentConfig(Experiment): :type nodelist: str or None :param seed: Random seed. :type seed: int + :param cache_clear_freq: The cache of data transfer will be cleared after each ``cache_clear_freq`` steps. + If None, will not clear the cache. Set to a small number, e.g., 1, if OOM or CUDA OOM occurs. + :type cache_clear_freq: int or None :param exp_ctrl: The save and evaluation control of the experiment. :type exp_ctrl: ExperimentSaveEvalControl """ @@ -172,6 +175,7 @@ class CommonExperimentConfig(Experiment): n_gpus_per_node: int = 8 nodelist: Optional[str] = None seed: int = 1 + cache_clear_freq: Optional[int] = 10 exp_ctrl: ExperimentSaveEvalControl = dataclasses.field( default_factory=ExperimentSaveEvalControl ) @@ -413,8 +417,8 @@ def _get_model_worker_configs( seed=self.seed, shards=[], datasets=self.datasets, - cuda_cache_cleanliness=False, - cuda_cache_clear_freq=10, + cuda_cache_cleanliness=self.cache_clear_freq is not None, + cuda_cache_clear_freq=self.cache_clear_freq, tokenizer_name_or_path=self.tokenizer_name_or_path, )