Skip to content

Commit

Permalink
Add GPT 175B mlperf config
Browse files Browse the repository at this point in the history
Signed-off-by: Sangkug Lym <[email protected]>
  • Loading branch information
erhoo82 committed Jul 19, 2024
1 parent f1da8cc commit e908312
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# convergence (e.g., 300B tokens) is not guaranteed.
defaults:
- _self_
- optional tp_overlap@model.ub_tp_comm_overlap_cfg: ub_cfg_h100_h12288_tp8_mbs2_seqlen2048
- optional tp_overlap@model.ub_tp_comm_overlap_cfg: ub_cfg_h100_h12288_tp4_mbs2_seqlen2048

hydra:
searchpath:
Expand All @@ -16,7 +16,7 @@ run:
dependency: "singleton"

trainer:
num_nodes: 128
num_nodes: 64
devices: 8
accelerator: gpu
precision: bf16
Expand Down Expand Up @@ -61,7 +61,7 @@ model:
micro_batch_size: 2
global_batch_size: 2048
context_parallel_size: 1
tensor_model_parallel_size: 8
tensor_model_parallel_size: 4
pipeline_model_parallel_size: 8
virtual_pipeline_model_parallel_size: 6 # interleaved pipeline, set to maximum
resume_from_checkpoint: null # manually set the checkpoint file to load from
Expand All @@ -73,16 +73,20 @@ model:
ffn_hidden_size: ${multiply:4, ${.hidden_size}} # Transformer FFN hidden size. 4 * hidden_size.
num_attention_heads: 96
init_method_std: 0.006 # Standard deviation of the zero mean normal distribution used for weight initialization.
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
attention_dropout: 0.1
use_scaled_init_method: false
hidden_dropout: 0.0 # Dropout probability for hidden state transformer.
attention_dropout: 0.0
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
apply_query_key_layer_scaling: false # scale Q * K^T by 1 / layer-number.
layernorm_epsilon: 1e-5
make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
pre_process: True # add embedding
post_process: True # add pooler
persist_layer_norm: True # Use of persistent fused layer norm kernel.
gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
normalization: layernorm1p
do_layer_norm_weight_decay: true
bias: true

# Fusion
grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce
Expand Down Expand Up @@ -145,8 +149,7 @@ model:
fp8_interval: 1 # scaling update interval
fp8_amax_history_len: 1024 # Number of steps for which amax history is recorded per tensor
fp8_amax_compute_algo: max # 'most_recent' or 'max'. Algorithm for computing amax from history
fp8_wgrad: True
ub_tp_comm_overlap: False
ub_tp_comm_overlap: True

# miscellaneous
seed: 1234
Expand All @@ -168,10 +171,11 @@ model:

optim:
name: distributed_fused_adam
bucket_cap_mb: 220
bucket_cap_mb: 125
overlap_grad_sync: True
overlap_param_sync: true
contiguous_grad_buffer: True
contiguous_param_buffer: True
grad_sync_dtype: bf16
lr: 0.9e-4
weight_decay: 0.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ qkv_dgrad:

qkv_wgrad:
method: bulk
num_sm: 8
num_sm: 4
cga_size: 2
set_sm_margin: 0

Expand Down Expand Up @@ -41,7 +41,7 @@ fc1_fprop:

fc2_dgrad:
method: ring_exchange
aggregate: 1
aggregate: 0

# Chunked-collective overlap with ReduceScatter
proj_fprop:
Expand All @@ -50,10 +50,9 @@ proj_fprop:
cga_size: 2
num_splits: 4
set_sm_margin: 1
fp8_buf: 1

fc2_fprop:
method: pipeline
num_sm: 20
cga_size: 2
num_splits: 4
method: ring_exchange
num_sm: 1
set_sm_margin: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# UB communicator configurations
# Model configs: H100/175B/TP4/MBS2/SeqLen2K/FP8

# Each top-level key below is one overlap entry; the indented keys beneath it
# are that entry's settings. The nesting is significant: without it the
# per-entry settings collapse into duplicate top-level keys.
# NOTE(review): the meaning of num_sm, cga_size, set_sm_margin, aggregate,
# num_splits and fp8_buf is defined by the consumer of this config -- confirm
# there before tuning any value.

# Bulk overlap with AllGather / ReduceScatter
qkv_dgrad:
  method: bulk
  num_sm: 4
  cga_size: 2
  set_sm_margin: 0

qkv_wgrad:
  method: bulk
  num_sm: 4
  cga_size: 2
  set_sm_margin: 0

fc1_dgrad:
  method: bulk
  num_sm: 2
  cga_size: 2
  set_sm_margin: 0

fc1_wgrad:
  method: bulk
  num_sm: 4
  cga_size: 2
  set_sm_margin: 0

# Ring-exchange overlap with AllGather
qkv_fprop:
  method: ring_exchange
  aggregate: 0

proj_dgrad:
  method: ring_exchange
  aggregate: 0

fc1_fprop:
  method: ring_exchange
  aggregate: 0

fc2_dgrad:
  method: ring_exchange
  aggregate: 0

# Chunked-collective overlap with ReduceScatter
proj_fprop:
  method: pipeline
  num_sm: 24
  cga_size: 2
  num_splits: 4
  set_sm_margin: 1
  fp8_buf: 1

# NOTE(review): fc2_fprop uses ring_exchange although it sits under the
# "chunked-collective" heading; values kept exactly as committed -- confirm
# the grouping is intentional.
fc2_fprop:
  method: ring_exchange
  num_sm: 1
  set_sm_margin: 1

0 comments on commit e908312

Please sign in to comment.