Commit
fix allreduce latency and memory usage calculation when using tp (#28)
* fix allreduce latency and mem usage when tp is in use

* update latency calculation in allgather
cli99 authored Nov 13, 2024
1 parent dfd4da9 commit d841e40
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions llm_analysis/analysis.py
@@ -475,9 +475,9 @@ def get_memory_optimizer_state_and_gradient_per_layer(

         memory_optimizer_state_others_per_layer = op_bytes_per_params * (
             (self.get_num_params_per_layer_attn() +
-             self.get_num_params_per_layer_router() +
-             self.get_num_params_per_layer_layernorm())
-        ) / self.parallelism_config.tp_size / sharded_dp_size
+             self.get_num_params_per_layer_router()) /
+            self.parallelism_config.tp_size +
+            self.get_num_params_per_layer_layernorm()) / sharded_dp_size

memory_optimizer_state_per_layer = memory_optimizer_state_mlp_per_layer + memory_optimizer_state_others_per_layer
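
The change keeps layernorm parameters out of the tensor-parallel sharding: attention and router parameters are divided by tp_size, while layernorm parameters, which are replicated across tp ranks, are only divided by the sharded dp size. A rough numerical sketch of the old and new expressions, with made-up parameter counts and optimizer bytes (all values below are assumptions, not from the repo):

# Illustrative sketch of the corrected per-layer optimizer-state estimate.
op_bytes_per_params = 8            # e.g. fp32 Adam moments, 4 + 4 bytes per param
attn_params = 50_000_000           # hypothetical attention params per layer
router_params = 1_000_000          # hypothetical MoE router params per layer
layernorm_params = 20_000          # hypothetical layernorm params per layer
tp_size, sharded_dp_size = 8, 4

# before: layernorm params were also divided by tp_size
old = op_bytes_per_params * (attn_params + router_params +
                             layernorm_params) / tp_size / sharded_dp_size

# after: only attention/router params are tp-sharded; layernorm params are
# replicated across tp ranks and sharded only by the sharded dp dimension
new = op_bytes_per_params * ((attn_params + router_params) / tp_size +
                             layernorm_params) / sharded_dp_size

print(old, new)  # the corrected estimate is slightly larger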

@@ -1218,9 +1218,9 @@ def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
         elems_per_all_reduce = (2 * batch_size * seq_len *
                                 self.model_config.hidden_dim * (tp_size - 1) /
                                 tp_size)
-        latency_per_all_reduce = (
-            elems_per_all_reduce * dtype_bytes /
-            (self.gpu_config.intra_node_bandwidth_in_GB_per_sec * 10**9))
+        # assuming tp_size <= number of GPUs per node, thus using intra-node bandwidth
+        latency_per_all_reduce = (elems_per_all_reduce * dtype_bytes /
+                                  (self.get_intra_node_bandwidth() * 10**9))
 
         return max(
             latency_per_all_reduce,
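
For scale, the estimate corresponds to a ring all-reduce that moves about 2 * (tp_size - 1) / tp_size of the activation tensor over intra-node links. A quick sketch with assumed shapes and bandwidth (all values hypothetical):

# Rough sketch of the per-all-reduce latency estimate above, with assumed values.
batch_size, seq_len, hidden_dim = 1, 2048, 4096    # hypothetical shapes
tp_size = 8
dtype_bytes = 2                                    # e.g. fp16/bf16 activations
intra_node_bw_GBps = 300                           # hypothetical NVLink-class bandwidth

# a ring all-reduce moves ~2 * (n - 1) / n of the tensor per rank
elems_per_all_reduce = (2 * batch_size * seq_len * hidden_dim *
                        (tp_size - 1) / tp_size)
latency_per_all_reduce = (elems_per_all_reduce * dtype_bytes /
                          (intra_node_bw_GBps * 10**9))
print(f"{latency_per_all_reduce * 1e6:.1f} us")    # ~97.9 us for these numbers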
@@ -1230,6 +1230,7 @@ def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
     def get_latency_fwd_per_layer_shared_dp_comm(self) -> float:
         dp_size = self.parallelism_config.dp_size
         ep_size = self.parallelism_config.ep_size
+        tp_size = self.parallelism_config.tp_size
 
         def time_allgather(S, n, B):
             # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#allgather
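
The body of time_allgather is not shown in this hunk; per the nccl-tests performance doc it links to, an all-gather of S total bytes across n ranks at bus bandwidth B takes roughly S * (n - 1) / (n * B). A minimal sketch of that cost model (an assumption about the helper's behavior, not the file's actual code):

# Sketch of the all-gather cost model from the nccl-tests link above:
# each rank receives (n - 1) / n of the S total bytes at bus bandwidth B.
def time_allgather_sketch(S: float, n: float, B: float) -> float:
    return S * (n - 1) / (n * B)

# e.g. 1 GB of gathered params across 4 dp ranks at 300 GB/s
print(time_allgather_sketch(1e9, 4, 300e9))  # ~2.5e-3 s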
@@ -1243,15 +1244,17 @@ def time_allgather(S, n, B):
             self.get_num_params_per_layer_layernorm()
         ) * self.dtype_config.weight_bits / BITS_PER_BYTE
 
-        latency_allgather_params_mlp = time_allgather(
-            params_bytes_mlp, dp_size / ep_size,
-            (self.get_intra_node_bandwidth()
-             if dp_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+        # assuming tp and dp are preferred when sharding intra node, pp is only applied across nodes
+        # when (dp_size * tp_size) <= 8, the data parallel processes are within a node
+        bandwidth = self.get_intra_node_bandwidth() if (
+            dp_size * tp_size) <= 8 else self.get_inter_node_bandwidth()
+
+        latency_allgather_params_mlp = time_allgather(params_bytes_mlp,
+                                                      dp_size / ep_size,
+                                                      bandwidth * 10**9)
 
         latency_allgather_params_non_mlp = time_allgather(
-            params_bytes_non_mlp, dp_size,
-            (self.get_intra_node_bandwidth()
-             if dp_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+            params_bytes_non_mlp, dp_size, bandwidth * 10**9)

latency_fwd_per_layer_shared_dp_comm = latency_allgather_params_mlp + latency_allgather_params_non_mlp
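
Putting the pieces together, the updated logic uses intra-node bandwidth only when the dp and tp ranks jointly fit on one 8-GPU node (dp_size * tp_size <= 8), and prices both all-gathers with that single bandwidth. A sketch with assumed sizes and bandwidths (all values hypothetical):

# Illustrative sketch of the updated bandwidth choice and the resulting
# per-layer shared-dp all-gather latency. All values are assumptions.
dp_size, ep_size, tp_size = 4, 1, 2
intra_node_bw, inter_node_bw = 300e9, 50e9          # hypothetical bytes/s
params_bytes_mlp, params_bytes_non_mlp = 4e8, 1e8   # hypothetical per-layer bytes

def time_allgather(S, n, B):
    return S * (n - 1) / (n * B)  # same cost model as sketched above

# dp ranks sit within a node only if dp and tp together fit on <= 8 GPUs
bandwidth = intra_node_bw if dp_size * tp_size <= 8 else inter_node_bw

latency = (time_allgather(params_bytes_mlp, dp_size / ep_size, bandwidth) +
           time_allgather(params_bytes_non_mlp, dp_size, bandwidth))
print(f"{latency * 1e3:.2f} ms")  # ~1.25 ms for these numbers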

