From 34287df0e948f98f312b54f99490a1d072172bbf Mon Sep 17 00:00:00 2001
From: fw
Date: Wed, 21 Aug 2024 15:35:47 +0000
Subject: [PATCH] .

---
 examples/profiling/profile_ppo_critic.sh | 15 ++++----
 realhf/base/monitor.py                   | 10 ++++++
 realhf/experiments/common/utils.py       |  6 +++-
 realhf/impl/model/backend/megatron.py    | 46 +++++++++++++++++++++++-
 realhf/system/model_worker.py            |  4 +--
 5 files changed, 69 insertions(+), 12 deletions(-)

diff --git a/examples/profiling/profile_ppo_critic.sh b/examples/profiling/profile_ppo_critic.sh
index d2cd7f0d..22f4a0d1 100644
--- a/examples/profiling/profile_ppo_critic.sh
+++ b/examples/profiling/profile_ppo_critic.sh
@@ -35,25 +35,24 @@ export CLUSTER_SPEC_PATH="/lustre/aigc/llm/cluster/qh.json"
 
 # Changing `model.init_critic_from_actor` and `model.type.is_critic` is important for profiling the critic.
 
-REAL_DUMP_KERNEL_TIME=0 REAL_DUMP_TRACE=0 REAL_DUMP_MEMORY=0 \
+REAL_DUMP_KERNEL_TIME=1 REAL_DUMP_TRACE=1 REAL_DUMP_MEMORY=1 \
 python3 -m realhf.apps.quickstart profile \
     mode=local \
     experiment_name=$EXP_NAME \
     trial_name=$TRIAL_NAME \
-    exp_ctrl.benchmark_steps=5 \
+    exp_ctrl.benchmark_steps=3 \
     exp_ctrl.save_freq_steps=null \
     exp_ctrl.eval_freq_steps=null \
     n_nodes=1 \
     model.type._class=$MODEL_FAMILY \
     model.path=$SFT_MODEL_PATH \
     dataset.path=/lustre/fw/datasets/imdb/rl/ppo_prompt.jsonl \
-    dataset.max_prompt_len=1024 \
-    dataset.train_bs_n_seqs=64 \
+    dataset.max_prompt_len=4096 \
+    dataset.train_bs_n_seqs=256 \
     dataset.pad_to_max_length=True \
-    handle_name=inference \
+    handle_name=train_step \
     interface_impl=ppo_critic \
     model.init_critic_from_actor=True \
     model.type.is_critic=True \
-    'n_mbs=[1, 2, 4]' \
-    interface_kwargs_json=./examples/profiling/interfaces/ppo_critic.json \
-    allocations_jsonl=./examples/profiling/allocations/local.jsonl
+    'n_mbs=[2]' \
+    interface_kwargs_json=./examples/profiling/interfaces/ppo_critic.json
diff --git a/realhf/base/monitor.py b/realhf/base/monitor.py
index 81698a31..aaf150da 100755
--- a/realhf/base/monitor.py
+++ b/realhf/base/monitor.py
@@ -463,6 +463,9 @@ def dump_tmark_db(worker_idx):
     "backward_kernel",
     "reduce_kernel",
     "multi_tensor_apply",
+    "gae_kernel",
+    "gemvx::kernel",
+    "cublas",
 ]
 
 COMM_KERNEL_KEYS = [
@@ -474,6 +477,7 @@ def dump_tmark_db(worker_idx):
 
 MEM_KERNEL_KEYS = [
     "Memcpy",
+    "cleanup",
     "Memset",
 ]
 
@@ -482,6 +486,10 @@ def dump_tmark_db(worker_idx):
     "CudaCodeGen",
 ]
 
+IGNORE_KERNEL_KEYS = [
+    "FusedAdam",  # This is a marker op above multi_tensor_apply
+]
+
 
 @dataclasses.dataclass
 class CUDAKernelTime:  # in us
@@ -518,6 +526,8 @@ def from_profiler(cls, p):
                 mem_time += x.self_cuda_time_total
             elif any(k in x.key for k in MISC_KERNEL_KEYS):
                 misc_time += x.self_cuda_time_total
+            elif any(k in x.key for k in IGNORE_KERNEL_KEYS):
+                continue
             else:
                 unknown_keys.append(x)
         if unknown_keys:
diff --git a/realhf/experiments/common/utils.py b/realhf/experiments/common/utils.py
index 0182fa0d..5da5ae5f 100644
--- a/realhf/experiments/common/utils.py
+++ b/realhf/experiments/common/utils.py
@@ -82,6 +82,9 @@ def make_train_backend_config(
             raise ValueError("Offload is not supported in Megatron backend.")
         if model_cfg.zero_stage == 3:
             raise ValueError("Zero stage 3 is not supported in Megatron backend.")
+        if model_cfg.zero_stage == 2:
+            logger.warning("Megatron does not support ZeRO stage 2. Falling back to stage 1.")
+            model_cfg.zero_stage = 1
         return ModelBackendAbstraction(
             "megatron",
             args=dict(
@@ -100,9 +103,10 @@ def make_train_backend_config(
                 min_lr_ratio=model_cfg.optimizer.min_lr_ratio,
                 enable_bf16=model_cfg.enable_bf16,
                 enable_fp16=model_cfg.enable_fp16,
+                # See MegatronTrainBackend for detailed explanations of these options.
                 use_zero_optimization=model_cfg.zero_stage > 0,
                 overlap_grad_reduce=model_cfg.zero_stage > 0,
-                overlap_param_gather=model_cfg.zero_stage > 0,
+                overlap_param_gather=False,
             ),
         )
     else:
diff --git a/realhf/impl/model/backend/megatron.py b/realhf/impl/model/backend/megatron.py
index 41c19731..cf5543e1 100644
--- a/realhf/impl/model/backend/megatron.py
+++ b/realhf/impl/model/backend/megatron.py
@@ -836,6 +836,50 @@ def generate(
 @dataclasses.dataclass
 class MegatronTrainBackend(model_api.ModelBackend):
+    """
+    When using the DistributedOptimizer of Megatron, parameters and gradients
+    will not be split across DP ranks, but optimizer states will be.
+    In other words, Megatron only supports ZeRO-1.
+
+    Megatron DDP splits the whole flattened parameter into buckets.
+    Buckets do not respect parameter boundaries and are dispatched to different DP ranks.
+    The optimizer on a specific DP rank only manages its own bucket,
+    but parameters and gradients are held by all ranks and are not further split.
+    (That's why only optimizer states are partitioned.) During backward, bucket gradients
+    will be scatter-reduced (controlled by the `use_distributed_optimizer` option
+    in Megatron DDP; otherwise an all-reduce will be issued), and parameters will then
+    be updated locally. At this point, the parameters are not synced across DP ranks.
+    The DistributedOptimizer will then call all-gather on the parameters.
+
+    Since Megatron allocates static tensors for scatter-reducing parameter gradients,
+    it does not decrease memory usage as DeepSpeed ZeRO-2 does. To be more specific,
+    with dynamic allocation, we can allocate gradient memory layer-by-layer. When the
+    backward finishes at layer N, we can scatter-reduce gradients and release the memory
+    after scattering. As a result, given DP size K, layer number L, and parameter size P
+    for each layer, dynamic allocation requires P * (1 + L/K) memory for gradients,
+    but Megatron requires P * L. Memory is not freed after scattering in Megatron.
+
+    `use_distributed_optimizer` enables bucketing and scatter-reducing gradients.
+    When set to False, optimizer states will not be partitioned.
+
+    `overlap_grad_reduce` issues the all-reduce/scatter-reduce on the fly
+    during backward once a gradient is ready, and should usually be enabled.
+
+    `overlap_param_gather` overlaps the parameter all-gather with the next forward pass.
+    It registers a forward hook that waits for the parameter all-gather issued
+    after the optimizer step. While this sounds good, it can be problematic with
+    parameter reallocation, because the reallocated parameters do not have the hook.
+    It can be enabled for SFT, but should be disabled for PPO.
+
+    As a final note, Megatron is in an awkward place for PPO with param-realloc.
+    First, it does not minimize the memory usage of gradients (i.e., no ZeRO-2).
+    Second, for functional correctness, we can't enable `overlap_param_gather`,
+    and a parameter update becomes a grad scatter-reduce plus a param all-gather, instead
+    of an all-reduce (running all-reduce requires setting `use_distributed_optimizer`
+    to False, but that will not partition optimizer states!), so it is not that
+    efficient, either. We use Megatron because it is the only backend that we can
+    make functionally correct. The DeepSpeed code is too hard to read and modify.
+    """
     optimizer_name: str = dataclasses.field(
         metadata={"choices": ["adam", "sgd"]},
         default="adam",
     )
@@ -852,7 +896,7 @@ class MegatronTrainBackend(model_api.ModelBackend):
     enable_bf16: bool = False
     use_zero_optimization: bool = True
     overlap_grad_reduce: bool = True
-    overlap_param_gather: bool = True
+    overlap_param_gather: bool = False
     accumulate_allreduce_grads_in_fp32: bool = False
     initial_loss_scale: float = 4096.0
     gradient_clipping: float = 1.0
diff --git a/realhf/system/model_worker.py b/realhf/system/model_worker.py
index 6cd03700..c083191e 100755
--- a/realhf/system/model_worker.py
+++ b/realhf/system/model_worker.py
@@ -688,7 +688,7 @@ def __maybe_profile_rpc(self, rpc: dfg.MFCDef):
             f"Collecting system metrics from the profiler. "
             "This may take a while..."
         )
-        parallel_str = f"d{self._dp_size}m{self._mp_size}p{self._pp_size}"
+        parallel_str = f"mb{rpc.n_mbs}d{self._dp_size}m{self._mp_size}p{self._pp_size}"
         if constants.sequence_parallel():
             parallel_str += "sp"
         if _dump_kernel_time:
@@ -736,7 +736,7 @@ def __maybe_profile_rpc(self, rpc: dfg.MFCDef):
         )
         os.makedirs(mem_trace_dir, exist_ok=True)
         torch.cuda.memory._dump_snapshot(
-            os.path.join(mem_trace_dir, f"mw{self.__worker_index}.pkl")
+            os.path.join(mem_trace_dir, f"{rpc.name}_r{dist.get_rank()}.pkl")
         )
 
     def __handle_model_function_calls(
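
Note on the gradient-memory estimate in the MegatronTrainBackend docstring above: the comparison of P * (1 + L/K) for dynamic, layer-by-layer allocation versus P * L for Megatron's static gradient buffers can be checked with a short back-of-the-envelope sketch. The values of param_bytes_per_layer, n_layers, and dp_size below are assumed examples for illustration, not numbers taken from the patch.

# Illustrative check of the gradient-memory estimate quoted in the docstring.
# All numbers are assumptions chosen only to make the comparison concrete.
GiB = 1024**3
param_bytes_per_layer = 0.9 * GiB  # P: gradient bytes per layer (assumed)
n_layers = 80                      # L: number of layers (assumed)
dp_size = 8                        # K: data-parallel world size (assumed)

# Dynamic allocation: one full layer gradient live at a time before it is
# scatter-reduced, plus a 1/K shard retained for every layer afterwards.
dynamic_bytes = param_bytes_per_layer * (1 + n_layers / dp_size)

# Megatron's static buffer: gradients for all layers stay allocated.
static_bytes = param_bytes_per_layer * n_layers

print(f"dynamic allocation: {dynamic_bytes / GiB:.1f} GiB")
print(f"megatron static   : {static_bytes / GiB:.1f} GiB")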