add sharded dp allgather time
cli99 committed Nov 13, 2023
1 parent 980a812 commit 9fb0a28
Showing 1 changed file with 97 additions and 64 deletions.
161 changes: 97 additions & 64 deletions llm_analysis/analysis.py
@@ -95,7 +95,6 @@ def __init__(
hbm_memory_efficiency: float = None,
intra_node_memory_efficiency: float = INTRA_NODE_MEMORY_EFFICIENCY,
inter_node_memory_efficiency: float = INTER_NODE_MEMORY_EFFICIENCY,
- intra_node_alltoall_efficiency: float = INTRA_NODE_ALLTOALL_EFFICIENCY,
) -> None:
"""LLMAnalysis constructor.
@@ -117,7 +116,6 @@ def __init__(
self.dtype_config = dtype_config
self.intra_node_memory_efficiency = intra_node_memory_efficiency
self.inter_node_memory_efficiency = inter_node_memory_efficiency
- self.intra_node_alltoall_efficiency = intra_node_alltoall_efficiency

if achieved_memory_bandwidth_GBs and hbm_memory_efficiency:
logger.info(
@@ -172,6 +170,10 @@ def __init__(
" parallelism")

self.total_num_params = self.get_num_params_total()
+ self.total_num_params_mlp = self.get_num_params_per_layer_mlp(
+ ) * self.model_config.num_layers
+ self.total_num_params_embedding = self.get_num_params_embedding()
+ self.total_num_params_others = self.total_num_params - self.total_num_params_mlp - self.total_num_params_embedding
self.total_num_active_params = self.get_num_active_params_total()
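
Note: the new breakdown splits total parameters into MLP, embedding, and everything else (attention blocks, layernorms). A toy check of the arithmetic, with assumed GPT-3-like counts rather than values from this commit:

total, mlp, embedding = 175e9, 115e9, 12.9e9   # assumed illustrative counts
others = total - mlp - embedding               # attention, layernorms, etc.
print(f"{others / 1e9:.1f}B non-MLP, non-embedding params")  # 47.1B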

def update_model_config(self, model_config: ModelConfig) -> None:
@@ -1043,20 +1045,24 @@ def get_latency_fwd_per_layer_attn(
logger.debug(
"latency_fwd_per_layer_attn:"
f" {round(max(compute_latency, memory_latency)*1000, 3)} ms"
" (max(compute_latency, weight_memory_latency+"
" activation_memory_latency) ="
" (max(compute_latency, weight_memory_latency +"
" activation_memory_latency):"
f" max({round(compute_latency*1000, 3)},"
f" ({round(weight_memory_latency*1000, 3)} +"
f" {round(activation_memory_latency*1000, 3)})))")
f" {round(weight_memory_latency*1000, 3)} +"
f" {round(activation_memory_latency*1000, 3)}))")

return max(compute_latency, memory_latency)
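
Note: the attention latency above is a roofline-style estimate, taking the larger of compute time and memory time. A minimal self-contained sketch of that model; the specs below are illustrative A100-like assumptions, not values from this commit:

# Roofline-style latency: whichever of compute or memory movement
# dominates sets the layer time.
def roofline_latency(flops: float, bytes_moved: float,
                     peak_tflops: float = 312.0,
                     hbm_gb_per_sec: float = 2039.0) -> float:
    compute_latency = flops / (peak_tflops * 1e12)          # seconds
    memory_latency = bytes_moved / (hbm_gb_per_sec * 1e9)   # seconds
    return max(compute_latency, memory_latency)

# Example: 1 TFLOP of attention math moving 10 GB of weights/activations
# is memory-bound: ~4.9 ms of HBM traffic vs ~3.2 ms of compute.
print(roofline_latency(1e12, 10e9))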

def get_latency_fwd_per_layer_mlp_moe_alltoall(self, batch_size: int,
seq_len: int) -> float:
data_nums = self.model_config.moe_top_k * batch_size * seq_len * self.model_config.hidden_dim
data_bytes = data_nums * self.dtype_config.activation_bits / BITS_PER_BYTE
- return data_bytes / (self.gpu_config.intra_node_bandwidth_in_GB_per_sec
- * self.intra_node_alltoall_efficiency * 10**9)

+ latency = data_bytes / (self.get_intra_node_bandwidth() * 10**9)
+ logger.info(
+ f'moe_alltoall data_bytes = {_num_to_string(data_bytes)}B, latency = {round(latency*1000, 3)} ms'
+ )
+ return latency
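
Note: the rewritten alltoall estimate sizes the payload as moe_top_k * batch * seq * hidden activations and divides by raw intra-node bandwidth (the commit drops the separate alltoall efficiency factor). A worked sketch with assumed Mixtral-like values:

# Assumed values: top_k=2, hidden=4096, fp16 activations, ~300 GB/s NVLink.
moe_top_k, batch_size, seq_len, hidden_dim = 2, 1, 2048, 4096
activation_bytes = 2                     # fp16
intra_node_bw_GBs = 300

data_bytes = moe_top_k * batch_size * seq_len * hidden_dim * activation_bytes
latency = data_bytes / (intra_node_bw_GBs * 1e9)
print(f"{data_bytes / 1e6:.1f} MB, {latency * 1e3:.3f} ms")  # 33.6 MB, 0.112 ms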

def get_latency_fwd_per_layer_mlp(
self,
@@ -1109,8 +1115,11 @@ def get_latency_fwd_per_layer_mlp(
" (max(compute_latency, weight_memory_latency+"
" activation_memory_latency) ="
f" max({round(compute_latency*1000, 3)},"
f" ({round(weight_memory_latency*1000, 3)} +"
f" {round(activation_memory_latency*1000, 3)})))")
f" {round(weight_memory_latency*1000, 3)} +"
f" {round(activation_memory_latency*1000, 3)}))")

logger.debug(
f'alltoall_latency = {round(alltoall_latency*1000, 3)} ms')

return max(compute_latency, memory_latency) + alltoall_latency

@@ -1196,7 +1205,7 @@ def time_allgather(S, n, B):
latency_allgather_params_mlp = time_allgather(
params_bytes_mlp, dp_size / ep_size,
(self.get_intra_node_bandwidth()
- if ep_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+ if ep_size < 8 else self.get_inter_node_bandwidth()) * 10**9)
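
Note: the comparison tightens from `<= 8` to `< 8`, presumably because when expert parallelism occupies a full 8-GPU node, the dp_size / ep_size allgather group for MLP shards spans nodes and should be charged at inter-node bandwidth. This reading is an inference from the surrounding code, not stated in the commit.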

latency_allgather_params_non_mlp = time_allgather(
params_bytes_non_mlp, dp_size,
Expand All @@ -1205,8 +1214,8 @@ def time_allgather(S, n, B):

latency_fwd_per_layer_shared_dp_comm = latency_allgather_params_mlp + latency_allgather_params_non_mlp

- logger.debug(
- f'params_bytes_mlp: {_num_to_string(params_bytes_mlp)}B, params_bytes_non_mlp: {_num_to_string(params_bytes_non_mlp)}B,latency_allgather_params_mlp: {round(latency_allgather_params_mlp*1000, 3)} ms, latency_allgather_params_non_mlp: {round(latency_allgather_params_non_mlp*1000, 3)} ms'
+ logger.info(
+ f'params_bytes_mlp: {_num_to_string(params_bytes_mlp)}B, params_bytes_non_mlp: {_num_to_string(params_bytes_non_mlp)}B, latency_allgather_params_mlp: {round(latency_allgather_params_mlp*1000, 3)} ms, latency_allgather_params_non_mlp: {round(latency_allgather_params_non_mlp*1000, 3)} ms'
)

return latency_fwd_per_layer_shared_dp_comm
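
Note: time_allgather's body is outside this hunk. A common ring-allgather cost model (an assumption here, not necessarily the exact formula in the file) is S * (n - 1) / (n * B) for S bytes gathered across n ranks at B bytes/s per link:

def time_allgather(S: float, n: float, B: float) -> float:
    # Each rank sends/receives (n - 1) shards of size S / n.
    return S * (n - 1) / (n * B)

# Example: unsharding 1 GB of MLP parameters across 8 ranks over an
# assumed 300 GB/s intra-node link takes ~2.9 ms.
print(time_allgather(1e9, 8, 300e9))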
@@ -1219,6 +1228,7 @@ def get_latency_fwd_per_layer(
activation_recomputation:
ActivationRecomputation = ActivationRecomputation.NONE,
layernorm_dtype_bytes: int = BYTES_FP32,
+ ds_zero: DSZeRO = DSZeRO.NONE,
) -> tuple:
"""Get the latency for the forward pass of a transformer layer, given the batch
size, sequence length, training or inference, activation recomputation strategy,
@@ -1232,10 +1242,14 @@
is_inference (bool, optional): whether it is inference or not. Defaults to True.
activation_recomputation (ActivationRecomputation, optional): activation recomputation strategy. Defaults to ActivationRecomputation.NONE.
layernorm_dtype_bytes (int, optional): number of bytes in the data type for the layernorm activations. Defaults to BYTES_FP32. Often has to be FP32 in training to maintain model accuracy.
+ ds_zero (DSZeRO, optional): which DeepSpeed ZeRO stage to use. Defaults to DSZeRO.NONE (disabled).
Returns:
tuple: a tuple of the latency in seconds for the forward pass of a transformer layer and its breakdown dict
"""
+ if ds_zero != ds_zero.NONE:
+ assert not is_inference, "DeepSpeed ZeRO is only supported in training"

latency_fwd_per_layer_attn = self.get_latency_fwd_per_layer_attn(
batch_size, seq_len, is_inference, activation_recomputation)

@@ -1258,35 +1272,36 @@ def get_latency_fwd_per_layer(
seq_len,
self.dtype_config.activation_bits / BITS_PER_BYTE,
)

- latency_fwd_per_layer_shared_dp_comm = self.get_latency_fwd_per_layer_shared_dp_comm(
- )
logger.debug(
f"latency_fwd_per_layer_tp_comm: {round(latency_fwd_per_layer_tp_comm*1000, 3)} ms"
)

- latency_per_layer = (latency_fwd_per_layer_attn +
- latency_fwd_per_layer_mlp +
- 2 * latency_fwd_per_layer_layernorm +
- 2 * latency_fwd_per_layer_tp_comm +
- latency_fwd_per_layer_shared_dp_comm)
- logger.debug(
- f"latency_fwd_per_layer_shared_dp_comm: {round(latency_fwd_per_layer_shared_dp_comm*1000, 3)} ms"
+ latency_fwd_per_layer_shared_dp_comm = self.get_latency_fwd_per_layer_shared_dp_comm(
)

- logger.debug(
- f"latency_per_layer: {round(latency_per_layer*1000, 3)} ms"
- f" ({round(latency_fwd_per_layer_attn*1000, 3)} +"
+ latency_per_layer = latency_fwd_per_layer_attn + latency_fwd_per_layer_mlp + 2 * latency_fwd_per_layer_layernorm + 2 * latency_fwd_per_layer_tp_comm

+ if ds_zero > DSZeRO.STAGE_1 and latency_fwd_per_layer_shared_dp_comm > latency_per_layer:
+ logger.warning(
+ f'allgather communication time to unshard model weight {round(latency_fwd_per_layer_shared_dp_comm*1000, 3)} ms is larger than compute {round(latency_per_layer*1000, 3)} ms, thus cannot be fully overlapped.'
+ )
+ latency_per_layer = max(latency_per_layer,
+ latency_fwd_per_layer_shared_dp_comm)

+ logger.info(
+ f"latency_per_layer: {round(latency_per_layer*1000, 3)} ms (max(attn + mlp + 2*layernorm + 2*tp_comm, shared_dp_comm):"
+ f" max({round(latency_fwd_per_layer_attn*1000, 3)} +"
f" {round(latency_fwd_per_layer_mlp*1000, 3)} +"
f" {round(2*latency_fwd_per_layer_layernorm*1000, 3)} +"
f" {round(2*latency_fwd_per_layer_tp_comm*1000, 3)} +"
f" {round(latency_fwd_per_layer_shared_dp_comm*1000, 3)})")
f" {round(2*latency_fwd_per_layer_tp_comm*1000, 3)},"
f" {round(latency_fwd_per_layer_shared_dp_comm*1000, 3)}))")

breakdown_per_layer = {
"attn": latency_fwd_per_layer_attn,
"mlp": latency_fwd_per_layer_mlp,
"layernorm": 2 * latency_fwd_per_layer_layernorm,
"tp_comm": 2 * latency_fwd_per_layer_tp_comm,
"sharded_dp_comm": latency_fwd_per_layer_shared_dp_comm
}

return latency_per_layer, breakdown_per_layer
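
Note: the reworked per-layer latency assumes the parameter allgather overlaps with compute, so it only adds time when it exceeds the compute path. A minimal sketch of that overlap model, with illustrative numbers:

def layer_latency(attn, mlp, layernorm, tp_comm, shared_dp_comm):
    compute_path = attn + mlp + 2 * layernorm + 2 * tp_comm
    # Allgather runs concurrently with compute; only the longer one counts.
    return max(compute_path, shared_dp_comm)

# Example: a 0.8 ms allgather hides entirely under 1.2 ms of compute.
print(layer_latency(0.5e-3, 0.6e-3, 0.02e-3, 0.03e-3, 0.8e-3))  # 0.0012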
@@ -1341,6 +1356,7 @@ def get_latency_fwd(
ActivationRecomputation = ActivationRecomputation.NONE,
layernorm_dtype_bytes: int = BYTES_FP32,
breakdown_prefix: str = "",
+ ds_zero: DSZeRO = DSZeRO.NONE,
) -> tuple:
"""Get the latency for the forward pass of the transformer, given the batch
size, sequence length, and whether it is inference or not, the activation
@@ -1354,6 +1370,7 @@ def get_latency_fwd(
activation_recomputation (ActivationRecomputation, optional): activation recomputation strategy. Defaults to ActivationRecomputation.NONE.
layernorm_dtype_bytes (int, optional): number of bytes in the data type for the layernorm activations. Defaults to BYTES_FP32. Often has to be FP32 in training to maintain model accuracy.
breakdown_prefix (str, optional): prefix for the breakdown dict keys. Defaults to "".
+ ds_zero (DSZeRO, optional): which DeepSpeed ZeRO stage to use. Defaults to DSZeRO.NONE (disabled).
Returns:
tuple: a tuple of the latency in seconds for the forward pass of the transformer and its breakdown dict
"""
@@ -1369,9 +1386,10 @@ def get_latency_fwd(
is_inference,
activation_recomputation,
layernorm_dtype_bytes,
+ ds_zero,
)

- latency_fwd_all_layers = latency_fwd_per_layer * num_layers_per_gpu
+ latency_fwd_layers = latency_fwd_per_layer * num_layers_per_gpu

latency_fwd_input_embedding = self.get_latency_fwd_input_embedding(
batch_size,
@@ -1382,23 +1400,23 @@ def get_latency_fwd(
latency_fwd_output_embedding_loss = (
self.get_latency_fwd_output_embedding_loss(batch_size, seq_len))

- total_latency = (latency_fwd_all_layers + latency_fwd_input_embedding +
- latency_fwd_output_embedding_loss)
+ latency_fwd = (latency_fwd_layers + latency_fwd_input_embedding +
+ latency_fwd_output_embedding_loss)

logger.debug("latency_fwd_all_layers:"
f" {round(latency_fwd_all_layers*1000, 3)} ms"
f" ({round(latency_fwd_per_layer*1000, 3)} ms x"
f" {num_layers_per_gpu}), latency_fwd_input_embedding:"
f" {round(latency_fwd_input_embedding*1000, 3)} ms,"
" latency_fwd_output_embedding_loss:"
f" {round(latency_fwd_output_embedding_loss*1000, 3)} ms")
logger.info("latency_fwd_layers:"
f" {round(latency_fwd_layers*1000, 3)} ms"
f" ({round(latency_fwd_per_layer*1000, 3)} ms x"
f" {num_layers_per_gpu}), latency_fwd_input_embedding:"
f" {round(latency_fwd_input_embedding*1000, 3)} ms,"
" latency_fwd_output_embedding_loss:"
f" {round(latency_fwd_output_embedding_loss*1000, 3)} ms")

logger.debug(f"latency_fwd_total: {round(total_latency*1000, 3)} ms"
f" ({round(latency_fwd_all_layers*1000, 3)} +"
f" {round(latency_fwd_input_embedding*1000, 3)} +"
f" {round(latency_fwd_output_embedding_loss*1000, 3)})")
logger.info(
f"latency_fwd: {round(latency_fwd*1000, 3)} ms (layers + input_embedding + output_embedding_loss: "
f"{round(latency_fwd_layers*1000, 3)} + {round(latency_fwd_input_embedding*1000, 3)} + {round(latency_fwd_output_embedding_loss*1000, 3)})"
)

- total_breakdown = {
+ latency_fwd_breakdown = {
breakdown_prefix + "latency_fwd_attn":
breakdown_per_layer["attn"] * num_layers_per_gpu,
breakdown_prefix + "latency_fwd_mlp":
@@ -1407,12 +1425,18 @@ def get_latency_fwd(
breakdown_per_layer["layernorm"] * num_layers_per_gpu,
breakdown_prefix + "latency_fwd_tp_comm":
breakdown_per_layer["tp_comm"] * num_layers_per_gpu,
breakdown_prefix + "latency_fwd_sharded_dp_comm":
breakdown_per_layer["sharded_dp_comm"] * num_layers_per_gpu,
breakdown_prefix + "latency_fwd_input_embedding":
latency_fwd_input_embedding,
breakdown_prefix + "latency_fwd_output_embedding_loss":
latency_fwd_output_embedding_loss,
}
- return total_latency, total_breakdown
+ return latency_fwd, latency_fwd_breakdown

+ def get_latency_weight_update(self, ):
+ return self.weight_grad_op_state_memory_per_gpu / (
+ self.get_gpu_hbm_bandwidth() * 10**9)
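
Note: the new get_latency_weight_update models the optimizer step as memory-bound, streaming weights, gradients, and optimizer state through HBM once. A worked sketch with assumed values (not values from this commit):

# Assumed: ~60 GB of weights + gradients + Adam state on a GPU with
# ~2000 GB/s of HBM bandwidth.
weight_grad_op_state_bytes = 60e9
hbm_bandwidth_GBs = 2000

latency_weight_update = weight_grad_op_state_bytes / (hbm_bandwidth_GBs * 1e9)
print(f"{latency_weight_update * 1e3:.1f} ms")  # 30.0 ms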

def print_config(self, name="Training Configs") -> None:
config_str = f"\n{name.center(PRINT_LINE_WIDTH, '-')}\n"
@@ -1985,9 +2009,10 @@ def training(
optimizer_state_memory_per_gpu = optimizer_state_memory_per_layer * num_layers_per_gpu + optimizer_state_memory_embedding + optimizer_state_memory_last_layernorm
gradient_memory_per_gpu = gradient_memory_per_layer * num_layers_per_gpu + gradient_memory_embedding + gradient_memory_last_layernorm

+ self.weight_grad_op_state_memory_per_gpu = weight_memory_per_gpu + gradient_memory_per_gpu + optimizer_state_memory_per_gpu

memory_left = (self.gpu_config.mem_per_GPU_in_GB * 1024**3 -
- weight_memory_per_gpu - optimizer_state_memory_per_gpu -
- gradient_memory_per_gpu)
+ self.weight_grad_op_state_memory_per_gpu)

logger.info(
f"weight_memory_per_gpu: {_num_to_string(weight_memory_per_gpu)}B"
@@ -2142,24 +2167,24 @@ def training(
is_inference=False,
activation_recomputation=activation_recomputation,
layernorm_dtype_bytes=layernorm_dtype_bytes,
+ ds_zero=ds_zero,
)
+ # estimated by flops only:
+ # latency_per_micro_batch = num_flops_total_per_micro_batch / (
+ # (self.parallelism_config.tp_size * self.parallelism_config.pp_size)
+ # * self.get_TFLOPS_per_gpu() * 1e12)
+ latency_per_micro_batch = latency_fwd * 3
+ latency_weight_update = self.get_latency_weight_update()

- mp_size = (self.parallelism_config.tp_size *
- self.parallelism_config.pp_size)

- latency_per_micro_batch = num_flops_total_per_micro_batch / (
- mp_size * self.get_TFLOPS_per_gpu() * 1e12)

- latency_per_iter = (latency_per_micro_batch *
- gradient_accumulation_steps)
+ latency_per_iter = (
+ latency_per_micro_batch * gradient_accumulation_steps +
+ latency_weight_update)

logger.info(
"latency_per_micro_batch:"
f" {round(latency_per_micro_batch * 1000, 3)} ms, latency_fwd:"
f" {round(latency_fwd * 1000, 3)} ms, \nlatency_per_iter:"
f" {round(latency_per_iter * 1000, 3)} ms"
f" ({round(latency_per_micro_batch * 1000, 3)} ms *"
f" {gradient_accumulation_steps} gradient_accumulation_steps)")
f"latency_per_micro_batch: {round(latency_per_micro_batch * 1000, 3)} ms, "
f"latency_per_iter: {round(latency_per_iter * 1000, 3)} ms "
f"({round(latency_per_micro_batch * 1000, 3)} ms latency_fwd * {gradient_accumulation_steps} gradient_accumulation_steps + {round(latency_weight_update * 1000, 3)} ms weight_update)"
+ )
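
Note: the new iteration-time model replaces the FLOPs-only estimate: a micro-batch costs roughly 3x the forward pass (backward is approximated as 2x forward), and the weight update is paid once per iteration. A sketch with illustrative numbers:

latency_fwd = 0.1                       # s per micro-batch forward (assumed)
gradient_accumulation_steps = 8
latency_weight_update = 0.03            # s (assumed)

latency_per_micro_batch = latency_fwd * 3   # forward + ~2x for backward
latency_per_iter = (latency_per_micro_batch * gradient_accumulation_steps +
                    latency_weight_update)
print(latency_per_iter)  # 2.43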

total_num_gpus = (self.parallelism_config.tp_size *
self.parallelism_config.pp_size *
@@ -2230,6 +2255,12 @@ def training(
total_num_tokens,
"num_params_total":
self.total_num_params,
"num_params_total_mlp":
self.total_num_params_mlp,
"num_params_total_embedding":
self.total_num_params_embedding,
"num_params_total_others":
self.total_num_params_others,
"num_active_params_total":
self.total_num_active_params,
"activation_recomputation":
@@ -2263,8 +2294,7 @@ def training(
"optimizer_state_memory_per_gpu":
optimizer_state_memory_per_gpu,
"(weight+op_state+grad)_memory_per_gpu":
- weight_memory_per_gpu + gradient_memory_per_gpu +
- optimizer_state_memory_per_gpu,
+ self.weight_grad_op_state_memory_per_gpu,
"activation_memory_batch_size_1":
activation_memory_batch_size_1,
"activation_memory_per_gpu":
@@ -2278,8 +2308,8 @@ def training(
"activation_memory_embedding_output_per_gpu":
activation_memory_embedding_output_per_gpu,
"(weight+op_state+grad+act)_memory_per_gpu":
- weight_memory_per_gpu + gradient_memory_per_gpu +
- optimizer_state_memory_per_gpu + activation_memory_per_gpu,
+ self.weight_grad_op_state_memory_per_gpu +
+ activation_memory_per_gpu,
"memory_left_per_gpu":
memory_left,
"latency_per_micro_batch":
@@ -2290,6 +2320,7 @@ def training(
summary_dict.update(latency_fwd_breakdown)
summary_dict.update({
"latency_per_iter": latency_per_iter,
"iters_per_sec": round(1 / latency_per_iter, 2),
"total_training_latency": total_training_latency,
"gpu_hours": gpu_hours,
})
@@ -2511,6 +2542,8 @@ def train(
else:
dp_size = 1

+ assert ep_size <= 8, "only support ep_size up to 8 GPUs per node"

model_config = get_model_config_by_name(model_name)
gpu_config = get_gpu_config_by_name(gpu_name)
dtype_config = get_dtype_config_by_name(dtype_name)
