diff --git a/realhf/experiments/common/dpo_exp.py b/realhf/experiments/common/dpo_exp.py index f56e926c..f18f8ddb 100755 --- a/realhf/experiments/common/dpo_exp.py +++ b/realhf/experiments/common/dpo_exp.py @@ -55,10 +55,6 @@ class DPOConfig(CommonExperimentConfig): :type dataset: PairedComparisonDatasetConfig :param beta: KL regularization coefficient. :type beta: float - :param actor_train_n_mbs: Number of microbatches for training the primary LLM. - :type actor_train_n_mbs: int - :param ref_inf_n_mbs: Number of microbatches for inference on the reference LLM. - :type ref_inf_n_mbs: int """ is_sft_lora: bool = False diff --git a/realhf/experiments/common/ppo_exp.py b/realhf/experiments/common/ppo_exp.py index eecd20c0..8fd58d54 100755 --- a/realhf/experiments/common/ppo_exp.py +++ b/realhf/experiments/common/ppo_exp.py @@ -171,18 +171,6 @@ class PPOConfig(CommonExperimentConfig): :type dataset: PromptOnlyDatasetConfig :param ppo: Configuration for the PPO algorithm. :type ppo: PPOHyperparameters - :param actor_train_n_mbs: Number of minibatches for TrainActor. - :type actor_train_n_mbs: int - :param critic_train_n_mbs: Number of minibatches for TrainCritic. - :type critic_train_n_mbs: int - :param actor_gen_n_mbs: Number of minibatches for Rollout. - :type actor_gen_n_mbs: int - :param critic_inf_n_mbs: Number of minibatches for InfValues. - :type critic_inf_n_mbs: int - :param rew_inf_n_mbs: Number of minibatches for InfReward. - :type rew_inf_n_mbs: int - :param ref_inf_n_mbs: Number of minibatches for InfRef. - :type ref_inf_n_mbs: int """ is_sft_lora: bool = False diff --git a/realhf/experiments/common/rw_exp.py b/realhf/experiments/common/rw_exp.py index 5ce43f32..902bbe39 100755 --- a/realhf/experiments/common/rw_exp.py +++ b/realhf/experiments/common/rw_exp.py @@ -37,8 +37,6 @@ class RWConfig(CommonExperimentConfig): :type allocation: MFCConfig :param dataset: Dataset configuration. :type dataset: PairedComparisonDatasetConfig - :param n_mbs: Number of microbatches. - :type n_mbs: int """ is_sft_lora: bool = False diff --git a/realhf/experiments/common/sft_exp.py b/realhf/experiments/common/sft_exp.py index 78b66a19..c82d5278 100755 --- a/realhf/experiments/common/sft_exp.py +++ b/realhf/experiments/common/sft_exp.py @@ -28,8 +28,6 @@ class SFTConfig(CommonExperimentConfig): :type allocation: MFCConfig :param dataset: Dataset configuration :type dataset: PromptAnswerDatasetConfig - :param n_mbs: Number of microbatches. - :type n_mbs: int """ model: ModelTrainEvalConfig = dataclasses.field( diff --git a/realhf/impl/model/backend/inference.py b/realhf/impl/model/backend/inference.py index 0b48a8ad..3f1f1ede 100644 --- a/realhf/impl/model/backend/inference.py +++ b/realhf/impl/model/backend/inference.py @@ -93,7 +93,9 @@ def forward( if num_micro_batches is None: num_micro_batches = 1 outputs = [] - for mb_input in input_.split(num_micro_batches): + for mb_input in input_.split( + num_micro_batches, min_size=constants.pipe_parallel_world_size() + ): if constants.pipe_parallel_world_size() > 1: model_output = self.pipe_runner.forward( input_=mb_input, @@ -138,7 +140,9 @@ def generate( if num_micro_batches is None: num_micro_batches = 1 sequences, scores, logits_mask = [], [], [] - for mb_input in input_.split(num_micro_batches): + for mb_input in input_.split( + num_micro_batches, min_size=constants.pipe_parallel_world_size() + ): if constants.pipe_parallel_world_size() > 1: res = self.pipe_runner.generate( input_=mb_input, diff --git a/realhf/system/master_worker.py b/realhf/system/master_worker.py index 94aa74ec..5a252157 100755 --- a/realhf/system/master_worker.py +++ b/realhf/system/master_worker.py @@ -367,7 +367,7 @@ def _attach_payloads_with_hooks( return payloads, mwids -async def _request_model_function_call( +def _request_model_function_call( rpc: dfg.MFCDef, stream: request_reply_stream.NameResolvingRequestClient, msid2mwid: Dict[config_pkg.ModelShardID, int], @@ -438,14 +438,10 @@ async def _request_model_function_call( ) req_ids = [stream.post(p) for h, p in payloads.items() if h in handlers] other_req_ids = [stream.post(p) for h, p in payloads.items() if h not in handlers] - await asyncio.gather( - *[ - _awaitable_response( - stream, pattern=create_exact_match_pattern([p.syn_reply_id]) - ) - for p in payloads.values() - ] - ) + [ + stream.poll(block=True, pattern=create_exact_match_pattern([p.syn_reply_id])) + for p in payloads.values() + ] [ stream.post( request_reply_stream.Payload( @@ -585,7 +581,7 @@ async def model_rpc_request_func( producer_mappings[names[0], k] = producer_mapping # send partitioned data to model workers - req_ids, other_req_ids = await _request_model_function_call( + req_ids, other_req_ids = _request_model_function_call( rpc=rpc, stream=stream, msid2mwid=msid2mwid, @@ -984,9 +980,17 @@ def __lazy_init(self): # Build some data required for subsequent model function calls. self.__all_model_handlers: List[config_pkg.ModelShardID] = [] + self.__all_mw_handlers: List[config_pkg.ModelShardID] = [] + _covered_mws = set() self.__dp0_model_handlers: List[config_pkg.ModelShardID] = [] self.__trainable_model_handlers: List[config_pkg.ModelShardID] = [] for model_name, topo in self.config.model_topos.items(): + for j in range(topo.world_size()): + h = config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, j) + _mw_id = self.config.msid2mwid[h] + if _mw_id not in _covered_mws: + _covered_mws.add(_mw_id) + self.__all_mw_handlers.append(h) num_dp = topo.get_dim("data") self.__all_model_handlers += [ config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, j) @@ -1487,9 +1491,9 @@ def _log_training_stats(self, e2e_time: float, time_since_configure: float): def _clear_gpu_cache(self): request_all( self.__stream, - [vs[0] for vs in self.__mwid2msids.values()], + self.__all_mw_handlers, "clear_data_cache", - [self.__rpc_ctrl.ids_to_clear for _ in self.__all_model_handlers], + [self.__rpc_ctrl.ids_to_clear for _ in self.__all_mw_handlers], ) self.__rpc_ctrl.ids_to_clear.clear() @@ -1515,9 +1519,9 @@ def experiment_complete_exit(self): # Model workers will not respond to this message. request_all( self.__stream, - handlers=self.__all_model_handlers, + handlers=self.__all_mw_handlers, handle_type="reset", - datas=[None for _ in self.__all_model_handlers], + datas=[None for _ in self.__all_mw_handlers], ) self.__stream.close() constants.reset_run() diff --git a/realhf/system/model_worker.py b/realhf/system/model_worker.py index a6bd8873..385d1f2b 100755 --- a/realhf/system/model_worker.py +++ b/realhf/system/model_worker.py @@ -365,7 +365,7 @@ def __lazy_setup(self): eval_dataloader = None self.__eval_dataloaders[s.id.model_name] = eval_dataloader - self.__request_cache = [] + self.__request_cache = {} self.__ack_cache = {} self.__request_queue = queue.Queue(maxsize=8) @@ -607,6 +607,7 @@ def model_poll_step( del self.__data_sent_worker_indices[_id] if _id in self.__data_received_worker_indices: del self.__data_received_worker_indices[_id] + gc.collect() if ( self.config.cuda_cache_cleanliness and self.__clear_cache_frequency.check() @@ -852,24 +853,19 @@ def __maybe_receive_one_request(self): handle_name="syn", ), ) - self.__request_cache.append(r) + self.__request_cache[r.ack_reply_id] = r except request_reply_stream.NoMessage: return @cuda_tmark("receive_request", CUDATimeMarkType.misc) def maybe_receive_requests(self): - for _ in range(16): - self.__maybe_receive_one_request() - - while len(self.__request_cache) > 0: - request: request_reply_stream.Payload = self.__request_cache[0] - while request.ack_reply_id not in self.__ack_cache: - self.__maybe_receive_one_request() - - self.__ack_cache.pop(request.ack_reply_id) - self.__request_cache.pop(0) - - self.__request_queue.put_nowait((request, request.data, False, None)) + self.__maybe_receive_one_request() + cur_ack_ids = list(self.__ack_cache.keys()) + for ack_id in cur_ack_ids: + if ack_id in self.__request_cache: + self.__ack_cache.pop(ack_id) + req = self.__request_cache.pop(ack_id) + self.__request_queue.put_nowait((req, req.data, False, None)) def _poll(self): if not self.__dist_env_resolved: