From 6601ec3d5965dcb3c8290eb9608dc0e6becf7383 Mon Sep 17 00:00:00 2001 From: fw Date: Wed, 28 Aug 2024 07:50:37 +0000 Subject: [PATCH 1/2] . --- realhf/api/core/system_api.py | 7 ---- realhf/system/master_worker.py | 60 ++++++++++++++++++++++++++++------ 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/realhf/api/core/system_api.py b/realhf/api/core/system_api.py index b7c10247..84002c4d 100755 --- a/realhf/api/core/system_api.py +++ b/realhf/api/core/system_api.py @@ -184,11 +184,6 @@ class ExperimentSaveEvalControl: :param benchmark_steps: Terminate the training after this number of steps. Used by system benchmark only. Please leave it to None for normal training. :type benchmark_steps: Optional[int] - :param save_eval_timeout: Timeout in seconds for saving and evaluation. - Will be used for the last step of the experiment. The master worker will sleep - for `save_eval_timeout` seconds to wait all save or evaluations to finish. - Defaults to 120 seconds. - :type save_eval_timeout: int """ total_train_epochs: int = 1 @@ -202,8 +197,6 @@ class ExperimentSaveEvalControl: eval_freq_secs: Optional[int] = None # benchmark benchmark_steps: Optional[int] = None - # Graceful exit - save_eval_timeout: int = 120 @dataclasses.dataclass diff --git a/realhf/system/master_worker.py b/realhf/system/master_worker.py index 16c614a4..ca80ee0f 100755 --- a/realhf/system/master_worker.py +++ b/realhf/system/master_worker.py @@ -148,7 +148,7 @@ async def gather_all_replies( return responses -async def group_rpc_blocked( +async def async_group_rpc( stream: request_reply_stream.NameResolvingRequestClient, handlers: List[Union[config_pkg.ModelShardID, str]], handle_type: str, @@ -162,6 +162,21 @@ async def group_rpc_blocked( return [p.data for p in payloads] +def group_rpc_blocked( + stream: request_reply_stream.NameResolvingRequestClient, + handlers: List[Union[config_pkg.ModelShardID, str]], + handle_type: str, + datas: List, + verbose: bool = True, +): + req_ids = request_all(stream, handlers, handle_type, datas, verbose=verbose) + payloads = [ + stream.poll(pattern=create_exact_match_pattern([req_id]), block=True) + for req_id in req_ids + ] + return [p.data for p in payloads] + + def _request_parameter_sync( stream: request_reply_stream.NameResolvingRequestClient, msid2mwid: Dict[config_pkg.ModelShardID, int], @@ -705,7 +720,7 @@ async def load_data_func( while not is_final_batch: # Send request to model workers to get the specification of data. # Data itself is not transferred to the master worker. 
- data_batches: List[data_api.DataBatchMeta] = await group_rpc_blocked( + data_batches: List[data_api.DataBatchMeta] = await async_group_rpc( stream, handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)], handle_type="fetch", @@ -807,12 +822,12 @@ async def model_eval_thread_func( ): while not stop_ctl.is_set(): epoch, epoch_step = await eval_queue.get() - eval_stats = await group_rpc_blocked( + eval_stats = await async_group_rpc( stream, handlers, "evaluate", [None for _ in handlers] ) eval_stats = _gather_stat(list(filter(lambda x: bool(x), eval_stats))) logger.info( - f"Evaluation results at epoch {epoch + 1} step {epoch_step + 1}: {eval_stats}" + f"Evaluation results at epoch {epoch} step {epoch_step}: {eval_stats}" ) @@ -833,7 +848,7 @@ async def model_save_thread_func( ) for s in handlers ] - await group_rpc_blocked(stream, handlers, "save", model_save_dirs) + await async_group_rpc(stream, handlers, "save", model_save_dirs) logger.info(f"Save models at epoch {epoch} step {epoch_step}.") @@ -1130,7 +1145,7 @@ def __lazy_init(self): ) _task = event_loop.create_task( - group_rpc_blocked( + async_group_rpc( self.__stream, handlers=_handlers, handle_type="initialize", @@ -1344,12 +1359,37 @@ def _poll(self): if is_new_epoch: if self._epoch > self.__total_train_epochs: - if should_eval or should_save: + if should_eval: + eval_stats = group_rpc_blocked( + self.__stream, + self.__all_model_handlers, + "evaluate", + [None for _ in self.__all_model_handlers], + ) + eval_stats = _gather_stat( + list(filter(lambda x: bool(x), eval_stats)) + ) + logger.info( + f"Evaluation results at epoch {self._epoch} step {self._epoch_step}: {eval_stats}" + ) + if should_save: + model_save_dirs = [ + os.path.join( + self.MODEL_SAVE_ROOT, + s.model_name.role, + f"epoch{self._epoch}epochstep{self._epoch_step}globalstep{self._global_step}", + ) + for s in self.__trainable_model_handlers + ] + group_rpc_blocked( + self.__stream, + self.__trainable_model_handlers, + "save", + model_save_dirs, + ) logger.info( - f"Waiting for all save/eval requests at the last step" - f" for {self.config.exp_ctrl.save_eval_timeout} secs..." + f"Save models at epoch {self._epoch} step {self._epoch_step}." ) - time.sleep(self.config.exp_ctrl.save_eval_timeout) self.experiment_complete_exit(f"Training completes! 
Yeah!!!") total_time_consumption = time.perf_counter() - self._train_start_time From fad2214c390265f4be7adb34de2d9a02cb66bef9 Mon Sep 17 00:00:00 2001 From: fw Date: Wed, 28 Aug 2024 08:38:35 +0000 Subject: [PATCH 2/2] fix gen exp --- examples/scripts/local/gen.sh | 4 +- realhf/api/core/data_api.py | 5 +- realhf/experiments/common/gen_exp.py | 33 ++++- realhf/impl/model/interface/gen_interface.py | 128 +++++++++++++++++-- 4 files changed, 152 insertions(+), 18 deletions(-) diff --git a/examples/scripts/local/gen.sh b/examples/scripts/local/gen.sh index fced444a..9d4c19e2 100644 --- a/examples/scripts/local/gen.sh +++ b/examples/scripts/local/gen.sh @@ -16,9 +16,9 @@ python3 -m realhf.apps.quickstart gen \ allocation_mode=manual \ model.type._class=$MODEL_FAMILY \ model.path=$SFT_MODEL_PATH \ - dataset.path=/lustre/fw/datasets/imdb/rl/ppo_prompt.jsonl \ + dataset.path=/lustre/fw/datasets/imdb/rl/ppo_prompt-small.jsonl \ dataset.max_prompt_len=1024 \ - dataset.train_bs_n_seqs=128 \ + dataset.train_bs_n_seqs=100 \ allocation.parallel.pipeline_parallel_size=1 \ allocation.parallel.model_parallel_size=2 \ allocation.parallel.data_parallel_size=4 \ diff --git a/realhf/api/core/data_api.py b/realhf/api/core/data_api.py index c0980426..9c467c12 100755 --- a/realhf/api/core/data_api.py +++ b/realhf/api/core/data_api.py @@ -48,7 +48,10 @@ def load_hf_tokenizer( if padding_side is not None: kwargs["padding_side"] = padding_side tokenizer = transformers.AutoTokenizer.from_pretrained( - model_name_or_path, fast_tokenizer=fast_tokenizer, **kwargs + model_name_or_path, + fast_tokenizer=fast_tokenizer, + trust_remote_code=True, + **kwargs, ) if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id diff --git a/realhf/experiments/common/gen_exp.py b/realhf/experiments/common/gen_exp.py index ee33316b..daad2e58 100644 --- a/realhf/experiments/common/gen_exp.py +++ b/realhf/experiments/common/gen_exp.py @@ -6,6 +6,7 @@ ModelInterfaceType, ModelName, ) +from realhf.api.core.data_api import DatasetUtility, load_hf_tokenizer from realhf.api.core.dfg import MFCDef from realhf.api.core.model_api import GenerationHyperparameters from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig @@ -43,6 +44,35 @@ class GenerationConfig(CommonExperimentConfig): ) allocation: MFCConfig = dataclasses.field(default_factory=MFCConfig) + output_file: str = "output.jsonl" + + def __post_init__(self): + from realhf.impl.dataset.prompt_dataset import PromptDataset + + util = DatasetUtility( + seed=0, + ddp_rank=0, + world_size=1, + tokenizer=load_hf_tokenizer(self.model.path), + ) + d = PromptDataset( + util, max_length=self.dataset.max_prompt_len, dataset_path=self.dataset.path + ) + if len(d) % self.dataset.train_bs_n_seqs != 0: + raise ValueError( + f"The size of the dataset must be a multiple of batch size for generation. " + f"Otherwise the final batch will be dropped. Please pad your dataset size with random prompts. " + f"Current dataset size: {len(d)}, batch size: {self.dataset.train_bs_n_seqs}." + ) + if self.output_file is not None: + if not self.output_file.endswith(".jsonl"): + raise ValueError("Output path must end with .jsonl") + if "/" in self.output_file: + raise ValueError( + "Output path must not contain '/'. It should be a simple " + "filename that will be saved to the logging directory." 
+ ) + @property def models(self): return { @@ -52,7 +82,8 @@ def models(self): @property def rpcs(self): interface = ModelInterfaceAbstraction( - "generation", args={"generation_config": self.gen} + "generation", + args={"generation_config": self.gen, "output_file": self.output_file}, ) gen = MFCDef( name="gen", diff --git a/realhf/impl/model/interface/gen_interface.py b/realhf/impl/model/interface/gen_interface.py index 5e5bb151..64522559 100644 --- a/realhf/impl/model/interface/gen_interface.py +++ b/realhf/impl/model/interface/gen_interface.py @@ -1,4 +1,7 @@ import dataclasses +import fcntl +import json +import os import colorama import torch @@ -6,12 +9,35 @@ import realhf.api.core.model_api as model_api from realhf.api.core.data_api import SequenceSample from realhf.base import constants, logging +from realhf.base.datapack import flat2d logger = logging.getLogger("Generation Interface", "benchmark") +def acquire_lock(lock_file): + fd = open(lock_file, "w") + fcntl.flock(fd, fcntl.LOCK_EX) + return fd + + +def release_lock(lock_fd): + fcntl.flock(lock_fd, fcntl.LOCK_UN) + lock_fd.close() + + +def write_dict_to_jsonl(dict_data, file_path, lock_file): + lock_fd = acquire_lock(lock_file) + try: + with open(file_path, "a") as file: + json.dump(dict_data, file) + file.write("\n") + finally: + release_lock(lock_fd) + + @dataclasses.dataclass class GenerationInterface(model_api.ModelInterface): + output_file: str | None = None generation_config: model_api.GenerationHyperparameters = dataclasses.field( default_factory=model_api.GenerationHyperparameters ) @@ -44,27 +70,101 @@ def generate( if res is None: return None - gen_tokens, logprobs, *_ = res + gen_tokens, *_ = res - # Decode and log the first generated sentence. - l = input_.seqlens["packed_prompts"][0][0] - tokens = torch.cat( - [input_.data["packed_prompts"][:l], gen_tokens[0]] - ).unsqueeze(0) - out = model.tokenizer.batch_decode( - tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True - ) - if constants.model_parallel_rank() == 0 and constants.is_last_pipe_stage(): + res = { + "generated_length": gen_tokens.shape[1], + "batch_size": gen_tokens.shape[0], + } + if not ( + constants.model_parallel_rank() == 0 and constants.is_last_pipe_stage() + ): + # Not DP head, return stats. + return res + + if self.output_file is not None: + + # Concatenate prompts with gen_tokens, decode, and output to file. 
+ prompt_lens = flat2d(input_.seqlens["packed_prompts"]) + gen_lengths = (gen_tokens != model.tokenizer.pad_token_id).logical_and( + gen_tokens != model.tokenizer.eos_token_id + ).sum(dim=-1) + 1 + gen_lengths = gen_lengths.clip(max=gen_tokens.shape[-1]) + assert len(gen_lengths) == len(prompt_lens) == input_.bs, ( + input_.bs, + len(prompt_lens), + len(gen_lengths), + ) + + prompt_tokens_lis = [] + ans_tokens_lis = [] + prompt_offset = 0 + for i, (prompt_len, gen_len) in enumerate(zip(prompt_lens, gen_lengths)): + prompt_tokens_lis.append( + input_.data["packed_prompts"][ + prompt_offset : prompt_offset + prompt_len + ] + ) + ans_tokens_lis.append(gen_tokens[i, :gen_len]) + prompt_offset += prompt_len + assert prompt_offset == sum(prompt_lens) + seq_tokens_lis = [ + torch.cat([a, b]) for a, b in zip(prompt_tokens_lis, ans_tokens_lis) + ] + + prompt_str = model.tokenizer.batch_decode( + prompt_tokens_lis, + skip_special_tokens=True, + clean_up_tokenization_spaces=True, + ) + ans_str = model.tokenizer.batch_decode( + ans_tokens_lis, + skip_special_tokens=True, + clean_up_tokenization_spaces=True, + ) + seq_str = model.tokenizer.batch_decode( + seq_tokens_lis, + skip_special_tokens=True, + clean_up_tokenization_spaces=True, + ) + + lock_file = os.path.join( + constants.LOG_ROOT, + constants.experiment_name(), + constants.trial_name(), + "_gen.lock", + ) + output_file = os.path.join( + constants.LOG_ROOT, + constants.experiment_name(), + constants.trial_name(), + self.output_file, + ) + if constants.data_parallel_rank() == 0: + logger.info(f"Dumping output to: {output_file}...") + for p, a, s, _id in zip(prompt_str, ans_str, seq_str, input_.ids): + d = dict( + prompt=p, + answer=a, + seq=s, + id=_id, + ) + write_dict_to_jsonl(d, output_file, lock_file) + else: + # Decode and log the first generated sentence. + l = input_.seqlens["packed_prompts"][0][0] + tokens = torch.cat( + [input_.data["packed_prompts"][:l], gen_tokens[0]] + ).unsqueeze(0) + out = model.tokenizer.batch_decode( + tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) dp_rank = constants.data_parallel_rank() logger.info( f"DP rank {dp_rank}, the first generated sequence " f"is: {colorama.Fore.YELLOW + colorama.Style.DIM}{out[0]}{colorama.Style.RESET_ALL}" ) - res = { - "generated_length": gen_tokens.shape[1], - "batch_size": gen_tokens.shape[0], - } return res
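
For reference, a minimal sketch of reading back the JSONL output that write_dict_to_jsonl appends above. It assumes the default output_file="output.jsonl" and the prompt/answer/seq/id fields written in generate(); the path argument is a placeholder, since the actual file is written under constants.LOG_ROOT/<experiment_name>/<trial_name>/ as constructed in the patch.

import json

def load_generations(path: str = "output.jsonl"):
    """Read back records produced by GenerationInterface.generate() (sketch)."""
    records = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            d = json.loads(line)
            # Each record holds the decoded prompt, the generated answer,
            # the concatenated prompt+answer sequence, and the sample id.
            records.append((d["id"], d["prompt"], d["answer"], d["seq"]))
    return records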