
Commit

.
garrett4wade committed Aug 21, 2024
1 parent 8289fa0 commit 34287df
Showing 5 changed files with 69 additions and 12 deletions.
15 changes: 7 additions & 8 deletions examples/profiling/profile_ppo_critic.sh
@@ -35,25 +35,24 @@ export CLUSTER_SPEC_PATH="/lustre/aigc/llm/cluster/qh.json"

# Changing `model.init_critic_from_actor` and `model.type.is_critic` is important for profiling the critic.

-REAL_DUMP_KERNEL_TIME=0 REAL_DUMP_TRACE=0 REAL_DUMP_MEMORY=0 \
+REAL_DUMP_KERNEL_TIME=1 REAL_DUMP_TRACE=1 REAL_DUMP_MEMORY=1 \
python3 -m realhf.apps.quickstart profile \
mode=local \
experiment_name=$EXP_NAME \
trial_name=$TRIAL_NAME \
-exp_ctrl.benchmark_steps=5 \
+exp_ctrl.benchmark_steps=3 \
exp_ctrl.save_freq_steps=null \
exp_ctrl.eval_freq_steps=null \
n_nodes=1 \
model.type._class=$MODEL_FAMILY \
model.path=$SFT_MODEL_PATH \
dataset.path=/lustre/fw/datasets/imdb/rl/ppo_prompt.jsonl \
-dataset.max_prompt_len=1024 \
-dataset.train_bs_n_seqs=64 \
+dataset.max_prompt_len=4096 \
+dataset.train_bs_n_seqs=256 \
dataset.pad_to_max_length=True \
-handle_name=inference \
+handle_name=train_step \
interface_impl=ppo_critic \
model.init_critic_from_actor=True \
model.type.is_critic=True \
-'n_mbs=[1, 2, 4]' \
-interface_kwargs_json=./examples/profiling/interfaces/ppo_critic.json \
-allocations_jsonl=./examples/profiling/allocations/local.jsonl
+'n_mbs=[2]' \
+interface_kwargs_json=./examples/profiling/interfaces/ppo_critic.json
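
The three REAL_DUMP_* switches flipped on above gate which profiling artifacts are dumped (kernel time, trace, and memory, echoed by the `_dump_kernel_time` and memory-snapshot paths in model_worker.py below). A minimal sketch of how such boolean environment flags are typically parsed; the helper name and accepted values are assumptions, not the actual realhf code:

    import os

    def _env_flag(name: str, default: str = "0") -> bool:
        # Treat "1"/"true"/"yes" (case-insensitive) as enabled; anything else is off.
        return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

    DUMP_KERNEL_TIME = _env_flag("REAL_DUMP_KERNEL_TIME")
    DUMP_TRACE = _env_flag("REAL_DUMP_TRACE")
    DUMP_MEMORY = _env_flag("REAL_DUMP_MEMORY")
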
10 changes: 10 additions & 0 deletions realhf/base/monitor.py
@@ -463,6 +463,9 @@ def dump_tmark_db(worker_idx):
"backward_kernel",
"reduce_kernel",
"multi_tensor_apply",
"gae_kernel",
"gemvx::kernel",
"cublas",
]

COMM_KERNEL_KEYS = [
@@ -474,6 +477,7 @@

MEM_KERNEL_KEYS = [
"Memcpy",
"cleanup",
"Memset",
]

@@ -482,6 +486,10 @@ def dump_tmark_db(worker_idx):
"CudaCodeGen",
]

+IGNORE_KERNEL_KEYS = [
+"FusedAdam", # This is a marker above multi-tensor-apply
+]


@dataclasses.dataclass
class CUDAKernelTime: # in us
@@ -518,6 +526,8 @@ def from_profiler(cls, p):
mem_time += x.self_cuda_time_total
elif any(k in x.key for k in MISC_KERNEL_KEYS):
misc_time += x.self_cuda_time_total
+elif any(k in x.key for k in IGNORE_KERNEL_KEYS):
+continue
else:
unknown_keys.append(x)
if unknown_keys:
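
The new keys extend the substring lists that bucket each profiled kernel's self CUDA time into compute, communication, memory, or misc, while IGNORE_KERNEL_KEYS silently skips profiler markers such as FusedAdam instead of reporting them as unknown. A minimal sketch of this bucketing pattern; only the keys visible in the diff are reproduced, the profiler rows are simplified to (key, self_cuda_time) pairs, and the COMM list contents are an assumption:

    COMPUTE_KERNEL_KEYS = ["backward_kernel", "reduce_kernel", "multi_tensor_apply",
                           "gae_kernel", "gemvx::kernel", "cublas"]
    COMM_KERNEL_KEYS = ["nccl"]  # assumed; not shown in the diff
    MEM_KERNEL_KEYS = ["Memcpy", "cleanup", "Memset"]
    MISC_KERNEL_KEYS = ["CudaCodeGen"]
    IGNORE_KERNEL_KEYS = ["FusedAdam"]  # profiler marker above multi_tensor_apply; skipped, not counted

    def bucket_kernel_times(rows):
        # rows: iterable of (kernel_key, self_cuda_time_us) pairs.
        times = {"compute": 0, "comm": 0, "mem": 0, "misc": 0}
        unknown = []
        for key, t in rows:
            if any(k in key for k in COMPUTE_KERNEL_KEYS):
                times["compute"] += t
            elif any(k in key for k in COMM_KERNEL_KEYS):
                times["comm"] += t
            elif any(k in key for k in MEM_KERNEL_KEYS):
                times["mem"] += t
            elif any(k in key for k in MISC_KERNEL_KEYS):
                times["misc"] += t
            elif any(k in key for k in IGNORE_KERNEL_KEYS):
                continue
            else:
                unknown.append(key)
        return times, unknown
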
6 changes: 5 additions & 1 deletion realhf/experiments/common/utils.py
@@ -82,6 +82,9 @@ def make_train_backend_config(
raise ValueError("Offload is not supported in Megatron backend.")
if model_cfg.zero_stage == 3:
raise ValueError("Zero stage 3 is not supported in Megatron backend.")
+if model_cfg.zero_stage == 2:
+logger.warning("Megatron does not support ZeRO stage 2. Degenerates to stage 1.")
+model_cfg.zero_stage = 1
return ModelBackendAbstraction(
"megatron",
args=dict(
@@ -100,9 +103,10 @@
min_lr_ratio=model_cfg.optimizer.min_lr_ratio,
enable_bf16=model_cfg.enable_bf16,
enable_fp16=model_cfg.enable_fp16,
+# See MegatronTrainBackend for detailed explanations about these options.
use_zero_optimization=model_cfg.zero_stage > 0,
overlap_grad_reduce=model_cfg.zero_stage > 0,
-overlap_param_gather=model_cfg.zero_stage > 0,
+overlap_param_gather=False,
),
)
else:
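
This hunk normalizes ZeRO settings for the Megatron backend: offload and stage 3 are rejected, stage 2 quietly degrades to stage 1 with a warning, and `overlap_param_gather` is now hard-coded to False (the new docstring in megatron.py below explains why). A minimal sketch of the same normalization, using a stand-in config object rather than realhf's actual model config:

    import dataclasses
    import logging

    logger = logging.getLogger(__name__)

    @dataclasses.dataclass
    class _StandInModelCfg:  # hypothetical stand-in for the real model config
        zero_stage: int = 1
        offload: bool = False

    def normalize_megatron_cfg(cfg: _StandInModelCfg) -> _StandInModelCfg:
        if cfg.offload:
            raise ValueError("Offload is not supported in Megatron backend.")
        if cfg.zero_stage == 3:
            raise ValueError("Zero stage 3 is not supported in Megatron backend.")
        if cfg.zero_stage == 2:
            logger.warning("Megatron does not support ZeRO stage 2. Degenerates to stage 1.")
            cfg.zero_stage = 1
        return cfg
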
46 changes: 45 additions & 1 deletion realhf/impl/model/backend/megatron.py
@@ -836,6 +836,50 @@ def generate(

@dataclasses.dataclass
class MegatronTrainBackend(model_api.ModelBackend):
"""
When using the DistributedOptimizer of Megatron, parameters and gradients
will not be splitted across DP ranks, but optimizer states will be.
In other words, Megatron only supports ZeRO-1.
Megatron DDP will split the whole flattend parameter into buckets.
Buckets do not respect parameter boundaries and are dispatched to different DP ranks.
The optimizer on a specific DP rank will only manage its own bucket,
but parameters and gradients are held by all ranks and will not be further splitted.
(That's why only optimizer states are partitioned.) During backward, bucket gradients
will be scatter-reduced (controlled by the `use_distributed_optimizer` option
in Megatron DDP, otherwise all-reduce will be issued), and parameters will then
be updated locally. At this point, the parameters are not synced across DP ranks.
The DistributedOptimizer will then call all-gather on parameters.
Since Megatron allocates static tensors for scatter-reducing parameter gradients,
it does not decrease memory usage just as DeepSpeed ZeRO-2. To be more specific,
with dynamic allocation, we can allocate gradient memory layer-by-layer. When the
backward finishes at layer N, we can scatter-reduce gradients and release the memory
after scattering. As a result, given DP size K, layer number L, and parameter size P
for each layer, dynamic allocation requires P * (1 + L/K) memory for gradients,
but Megatron requires P * L. Memory is not freed after scattering in Megatron.
'use_distributed_optimizer' enables bucketing and scatter-reduce gradients.
When setting to False, optimizer states will not be partitioned.
'overlap_grad_reduce' enables issuing all-reduce/scatter-reduce on the fly
during bacwkard once the gradient is ready, which should usually be enabled.
'overlap_param_gather' overlaps param all-gather with the next forward pass.
It creates a forward hook that waits for the previous parameter all-gather
after the optimizer step. While this sounds good, it can be problematic with
parameter reallocation, because the reallocated parameters do not have the hook.
Can be enabled for SFT, but should be disabled for PPO.
As a final note, Megatron is in an awkward place for PPO with param-realloc.
First, it does not minimize the memory usage of gradients (i.e., ZeRO-2).
Second, for functional correctness, we can't enable `overlap_param_gather`,
and a parameter update will be scatter-reduce grad + all-gather param, instead
of an all-reduce (running all-reduce requires setting `use_distributed_optimizer`
to False, but that will not partition optimizer states!), so it is not that
efficient, either. We use Megatron because it is the only backend that we can
make it functionally correct. The DeepSpeed code is too hard to read and modify.
"""
optimizer_name: str = dataclasses.field(
metadata={"choices": ["adam", "sgd"]},
default="adam",
@@ -852,7 +896,7 @@ class MegatronTrainBackend(model_api.ModelBackend):
enable_bf16: bool = False
use_zero_optimization: bool = True
overlap_grad_reduce: bool = True
-overlap_param_gather: bool = True
+overlap_param_gather: bool = False
accumulate_allreduce_grads_in_fp32: bool = False
initial_loss_scale: float = 4096.0
gradient_clipping: float = 1.0
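
The docstring added above compares gradient memory under dynamic, layer-by-layer allocation (roughly P * (1 + L/K)) against Megatron's static buffers (P * L). Plugging in illustrative numbers makes the gap concrete; the figures below are assumptions, not measurements:

    P_GB = 0.5  # assumed per-layer gradient size in GB
    L = 32      # assumed number of layers
    K = 8       # assumed data-parallel world size

    dynamic_gb = P_GB * (1 + L / K)  # grads released layer-by-layer after scatter-reduce
    megatron_gb = P_GB * L           # static buffers keep all layers' grads alive

    print(f"dynamic allocation: {dynamic_gb:.1f} GB, Megatron static buffers: {megatron_gb:.1f} GB")
    # dynamic allocation: 2.5 GB, Megatron static buffers: 16.0 GB
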
4 changes: 2 additions & 2 deletions realhf/system/model_worker.py
@@ -688,7 +688,7 @@ def __maybe_profile_rpc(self, rpc: dfg.MFCDef):
f"Collecting system metrics from the profiler. "
"This may take for a while..."
)
parallel_str = f"d{self._dp_size}m{self._mp_size}p{self._pp_size}"
parallel_str = f"mb{rpc.n_mbs}d{self._dp_size}m{self._mp_size}p{self._pp_size}"
if constants.sequence_parallel():
parallel_str += "sp"
if _dump_kernel_time:
@@ -736,7 +736,7 @@ def __maybe_profile_rpc(self, rpc: dfg.MFCDef):
)
os.makedirs(mem_trace_dir, exist_ok=True)
torch.cuda.memory._dump_snapshot(
os.path.join(mem_trace_dir, f"mw{self.__worker_index}.pkl")
os.path.join(mem_trace_dir, f"{rpc.name}_r{dist.get_rank()}.pkl")
)

def __handle_model_function_calls(
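
The snapshot path now encodes the RPC name and global rank (f"{rpc.name}_r{rank}.pkl") instead of the model-worker index, and the parallel string gains the micro-batch count, so dumps from different function calls and configurations no longer collide. The file written by torch.cuda.memory._dump_snapshot is an ordinary pickle that can be inspected offline or loaded into https://pytorch.org/memory_viz; a minimal loading sketch with an assumed file name:

    import pickle

    with open("train_step_r0.pkl", "rb") as f:  # hypothetical dump produced by the profiling run
        snapshot = pickle.load(f)

    # Recent PyTorch snapshots are dicts whose "segments" entries describe the
    # caching allocator's reserved memory per device.
    segments = snapshot.get("segments", [])
    reserved = sum(seg.get("total_size", 0) for seg in segments)
    print(f"{len(segments)} segments, {reserved / 2**30:.2f} GiB reserved")
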
