Commit
Signed-off-by: Sangkug Lym <[email protected]>
Showing 3 changed files with 78 additions and 15 deletions.
@@ -3,7 +3,7 @@
 # convergence (e.g., 300B tokens) is not guaranteed.
 defaults:
   - _self_
-  - optional [email protected]_tp_comm_overlap_cfg: ub_cfg_h100_h12288_tp8_mbs2_seqlen2048
+  - optional [email protected]_tp_comm_overlap_cfg: ub_cfg_h100_h12288_tp4_mbs2_seqlen2048

 hydra:
   searchpath:
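Reader's note (not part of the commit): the `optional [email protected]_tp_comm_overlap_cfg` default composes the named tp_overlap file, added later in this commit, under `model.ub_tp_comm_overlap_cfg` of the training config. A minimal sketch of the composed result, assuming standard Hydra package composition:

# hypothetical excerpt of the resolved config; keys come from the new tp_overlap file
training:
  model:
    ub_tp_comm_overlap_cfg:
      qkv_dgrad:
        method: bulk
        num_sm: 4
      # ... remaining GEMM entries from ub_cfg_h100_h12288_tp4_mbs2_seqlen2048.yaml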
@@ -16,7 +16,7 @@ run:
   dependency: "singleton"

 trainer:
-  num_nodes: 128
+  num_nodes: 64
   devices: 8
   accelerator: gpu
   precision: bf16
@@ -61,7 +61,7 @@ model:
   micro_batch_size: 2
   global_batch_size: 2048
   context_parallel_size: 1
-  tensor_model_parallel_size: 8
+  tensor_model_parallel_size: 4
   pipeline_model_parallel_size: 8
   virtual_pipeline_model_parallel_size: 6 # interleaved pipeline, set to maximum
   resume_from_checkpoint: null # manually set the checkpoint file to load from
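Reader's sanity check of the parallelism arithmetic implied by the node-count and TP changes above (batch sizes and pipeline settings are taken from the surrounding lines; not part of the commit):

# total GPUs:            num_nodes * devices = 64 * 8 = 512          (was 128 * 8 = 1024)
# model-parallel group:  TP * PP             = 4 * 8  = 32 GPUs      (was 8 * 8 = 64)
# data-parallel size:    512 / 32            = 16                    (was 1024 / 64 = 16, unchanged)
# grad-accum steps:      GBS / (MBS * DP)    = 2048 / (2 * 16) = 64  (unchanged)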
@@ -73,16 +73,20 @@ model:
   ffn_hidden_size: ${multiply:4, ${.hidden_size}} # Transformer FFN hidden size. 4 * hidden_size.
   num_attention_heads: 96
   init_method_std: 0.006 # Standard deviation of the zero mean normal distribution used for weight initialization.')
-  hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
-  attention_dropout: 0.1
+  use_scaled_init_method: false
+  hidden_dropout: 0.0 # Dropout probability for hidden state transformer.
+  attention_dropout: 0.0
   kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
-  apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
+  apply_query_key_layer_scaling: false # scale Q * K^T by 1 / layer-number.
   layernorm_epsilon: 1e-5
   make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
   pre_process: True # add embedding
   post_process: True # add pooler
   persist_layer_norm: True # Use of persistent fused layer norm kernel.
   gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
+  normalization: layernorm1p
+  do_layer_norm_weight_decay: true
+  bias: true

   # Fusion
   grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce
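For reference, assuming hidden_size: 12288 (implied by the h12288 in the overlap config's name; the line itself sits outside this hunk), the derived FFN and attention sizes resolve as below. Reader's note, not part of the diff:

# ffn_hidden_size = 4 * hidden_size          = 4 * 12288  = 49152
# kv_channels     = hidden_size / num_heads  = 12288 / 96 = 128   (used when kv_channels is null)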
@@ -145,8 +149,7 @@ model:
   fp8_interval: 1 # scaling update interval
   fp8_amax_history_len: 1024 # Number of steps for which amax history is recorded per tensor
   fp8_amax_compute_algo: max # 'most_recent' or 'max'. Algorithm for computing amax from history
-  fp8_wgrad: True
-  ub_tp_comm_overlap: False
+  ub_tp_comm_overlap: True

   # miscellaneous
   seed: 1234
@@ -168,10 +171,11 @@ model:

   optim:
     name: distributed_fused_adam
-    bucket_cap_mb: 220
+    bucket_cap_mb: 125
     overlap_grad_sync: True
     overlap_param_sync: true
     contiguous_grad_buffer: True
+    contiguous_param_buffer: True
     grad_sync_dtype: bf16
     lr: 0.9e-4
     weight_decay: 0.1
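A hedged reading of the distributed_fused_adam knobs touched here (the comments are general descriptions of these options, not text from the commit):

optim:
  bucket_cap_mb: 125             # smaller gradient buckets give earlier, finer-grained overlap of grad sync with backprop
  contiguous_param_buffer: True  # keep parameters in one contiguous buffer so bucketed param all-gathers can be overlapped
  grad_sync_dtype: bf16          # communicate gradients in bf16 to roughly halve reduction bandwidth vs. fp32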
@@ -257,3 +261,5 @@ model:
       - .0334
       - ${data_dir}/my-gpt3_29_text_document

+env_vars:
+  NVTE_FUSED_ATTN: 1
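The new env_vars block is presumably exported into the job environment by the launcher; NVTE_FUSED_ATTN=1 tells Transformer Engine to use its fused attention backend. Annotated restatement (the comment is a reader's assumption, not from the commit):

env_vars:
  NVTE_FUSED_ATTN: 1   # enable Transformer Engine's fused (cuDNN) attention kernels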
launcher_scripts/conf/training/tp_overlap/ub_cfg_h100_h12288_tp4_mbs2_seqlen2048.yaml (new file: 58 additions, 0 deletions)
@@ -0,0 +1,58 @@
# UB communicator configurations
# Model configs: H100/175B/TP4/MBS1/SeqLen2K/FP8

# Bulk overlap with AllGather / ReduceScatter
qkv_dgrad:
  method: bulk
  num_sm: 4
  cga_size: 2
  set_sm_margin: 0

qkv_wgrad:
  method: bulk
  num_sm: 4
  cga_size: 2
  set_sm_margin: 0

fc1_dgrad:
  method: bulk
  num_sm: 2
  cga_size: 2
  set_sm_margin: 0

fc1_wgrad:
  method: bulk
  num_sm: 4
  cga_size: 2
  set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
  method: ring_exchange
  aggregate: 0

proj_dgrad:
  method: ring_exchange
  aggregate: 0

fc1_fprop:
  method: ring_exchange
  aggregate: 0

fc2_dgrad:
  method: ring_exchange
  aggregate: 0

# Chunked-collective overlap with ReduceScatter
proj_fprop:
  method: pipeline
  num_sm: 24
  cga_size: 2
  num_splits: 4
  set_sm_margin: 1
  fp8_buf: 1

fc2_fprop:
  method: ring_exchange
  num_sm: 1
  set_sm_margin: 1
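Taken together with the first file in this commit, the new overlap config is selected through the training config's Hydra defaults and activated by the model flag; a condensed sketch of the two pieces already shown in the diff above:

defaults:
  - optional [email protected]_tp_comm_overlap_cfg: ub_cfg_h100_h12288_tp4_mbs2_seqlen2048

model:
  ub_tp_comm_overlap: True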