From 168d9a3718076a7d1598988c013c4a435228aab5 Mon Sep 17 00:00:00 2001
From: Cheng Li
Date: Tue, 12 Nov 2024 23:38:59 -0800
Subject: [PATCH] fix allreduce latency and mem usage when tp is in use

---
 llm_analysis/analysis.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py
index 20100c5..32b1cb6 100644
--- a/llm_analysis/analysis.py
+++ b/llm_analysis/analysis.py
@@ -475,9 +475,9 @@ def get_memory_optimizer_state_and_gradient_per_layer(
 
         memory_optimizer_state_others_per_layer = op_bytes_per_params * (
             (self.get_num_params_per_layer_attn() +
-             +self.get_num_params_per_layer_router() +
-             self.get_num_params_per_layer_layernorm())
-        ) / self.parallelism_config.tp_size / sharded_dp_size
+             +self.get_num_params_per_layer_router()) /
+            self.parallelism_config.tp_size +
+            self.get_num_params_per_layer_layernorm()) / sharded_dp_size
 
         memory_optimizer_state_per_layer = memory_optimizer_state_mlp_per_layer + memory_optimizer_state_others_per_layer
 
@@ -1218,9 +1218,9 @@ def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
         elems_per_all_reduce = (2 * batch_size * seq_len *
                                 self.model_config.hidden_dim *
                                 (tp_size - 1) / tp_size)
-        latency_per_all_reduce = (
-            elems_per_all_reduce * dtype_bytes /
-            (self.gpu_config.intra_node_bandwidth_in_GB_per_sec * 10**9))
+        # assuming tp_size <= number of GPUs per node, thus using intra-node bandwidth
+        latency_per_all_reduce = (elems_per_all_reduce * dtype_bytes /
+                                  (self.get_intra_node_bandwidth() * 10**9))
 
         return max(
             latency_per_all_reduce,
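
The corrected cost model in get_latency_fwd_per_tp_comm can be sanity-checked with the small standalone sketch below. It is illustrative only and not part of llm_analysis: the parameter names mirror the patch, while the helper name, the default bandwidth, and the example numbers are assumptions.

# Minimal sketch (not part of llm_analysis) of the tensor-parallel all-reduce
# cost model that the patch corrects; the helper and its defaults are hypothetical.
def tp_all_reduce_latency(batch_size: int,
                          seq_len: int,
                          hidden_dim: int,
                          tp_size: int,
                          dtype_bytes: int = 2,
                          intra_node_bandwidth_GBps: float = 300.0) -> float:
    """Estimate the latency in seconds of one tensor-parallel all-reduce.

    A ring all-reduce moves roughly 2 * (tp_size - 1) / tp_size times the
    message size per rank, which is where the leading 2 and the
    (tp_size - 1) / tp_size factor come from. Like the patch, this assumes
    tp_size does not exceed the number of GPUs per node, so intra-node
    bandwidth is the right divisor.
    """
    elems_per_all_reduce = (2 * batch_size * seq_len * hidden_dim *
                            (tp_size - 1) / tp_size)
    return (elems_per_all_reduce * dtype_bytes /
            (intra_node_bandwidth_GBps * 10**9))


if __name__ == "__main__":
    # Placeholder example: batch 1, 2048-token sequence, hidden_dim 8192,
    # TP degree 8, bf16 activations, 300 GB/s intra-node bandwidth.
    latency_s = tp_all_reduce_latency(1, 2048, 8192, 8)
    print(f"per-all-reduce latency: {latency_s * 1e6:.1f} us")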