diff --git a/realhf/experiments/benchmark/profile_exp.py b/realhf/experiments/benchmark/profile_exp.py index 537c23c9..aef08a56 100644 --- a/realhf/experiments/benchmark/profile_exp.py +++ b/realhf/experiments/benchmark/profile_exp.py @@ -193,22 +193,23 @@ def initial_setup(self) -> List[ExperimentConfig]: f"will be saved to: {setup_log_path}" ) with open(setup_log_path, "w") as f: + # Iterate over batch sizes in the outermost loop to delay a possible OOM error for ( + bs, pcfg, n_mbs, model_cfg, dataset_cfg, handle_name, interface_cfg, - bs, ) in itertools.product( + self.batch_sizes, self.parallel_kwargs, self.n_mbs, self.model_kwargs, self.dataset_kwargs, self.handle_names, self.interface_kwargs, - self.batch_sizes, ): if handle_name == "generate" and pcfg["use_sequence_parallel"]: continue