Commit 4027709
wip
cli99 committed Oct 19, 2023
1 parent 6bad7b4 commit 4027709
Showing 2 changed files with 15 additions and 13 deletions.
26 changes: 14 additions & 12 deletions llm_analysis/analysis.py
@@ -376,9 +376,9 @@ def get_memory_weight_per_layer(

memory_weight_attn_per_layer = self.get_num_params_per_layer_attn() * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size / sharded_dp_size

memory_weight_mlp_per_layer = (self.get_num_params_per_layer_mlp() / self.parallelism_config.ep_size + self.get_num_params_per_layer_router()) * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size
memory_weight_mlp_per_layer = (self.get_num_params_per_layer_mlp() / self.parallelism_config.ep_size + self.get_num_params_per_layer_router()) * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size / (sharded_dp_size/self.parallelism_config.ep_size)

memory_weight_layernorm_per_layer = self.get_num_params_per_layer_layernorm() * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size
memory_weight_layernorm_per_layer = self.get_num_params_per_layer_layernorm() * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size / sharded_dp_size

memory_weight_per_layer = memory_weight_attn_per_layer + memory_weight_mlp_per_layer + memory_weight_layernorm_per_layer

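For readers following the sharding change above, here is a minimal, self-contained sketch of the updated per-layer weight-memory arithmetic. The function and argument names are illustrative stand-ins, not llm-analysis's API; the divisions mirror the new lines in this hunk.

```python
BITS_PER_BYTE = 8

def weight_memory_per_layer(
    n_params_attn: float,
    n_params_mlp: float,
    n_params_router: float,
    n_params_layernorm: float,
    weight_bits: int = 16,     # e.g. fp16/bf16 weights
    tp_size: int = 1,          # tensor-parallel degree
    ep_size: int = 1,          # expert-parallel degree
    sharded_dp_size: int = 1,  # data-parallel degree over which weights are sharded
) -> float:
    """Per-layer weight memory in bytes, mirroring the divisions in the hunk above."""
    bytes_per_param = weight_bits / BITS_PER_BYTE
    # attention and layernorm weights: divided by tp_size and the full sharded dp degree
    attn = n_params_attn * bytes_per_param / tp_size / sharded_dp_size
    layernorm = n_params_layernorm * bytes_per_param / tp_size / sharded_dp_size
    # expert (MLP) weights are already split ep_size ways, so each expert slice
    # is further sharded over only sharded_dp_size / ep_size ranks
    mlp = ((n_params_mlp / ep_size + n_params_router) * bytes_per_param
           / tp_size / (sharded_dp_size / ep_size))
    return attn + mlp + layernorm

# illustrative numbers only
print(weight_memory_per_layer(5e7, 2e8, 1e5, 1e4,
                              tp_size=2, ep_size=4, sharded_dp_size=8))
```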
@@ -580,7 +580,6 @@ def get_memory_activation_per_layer_mlp(
seq_len * batch_size * hidden_dim / sp_size
) if with_dropout else 0

print(f'XXXX recompute_gelu = {recompute_gelu}')
if self.model_config.moe_num_experts == 1:
memory_activation_per_layer_mlp = (
(1 * seq_len * batch_size * hidden_dim / sp_size)
@@ -594,14 +593,14 @@ def get_memory_activation_per_layer_mlp(

return memory_activation_per_layer_mlp

def get_memory_activation_per_layer_layernorm(
def get_memory_activation_per_layernorm(
self,
batch_size: int,
seq_len: int,
activation_recomputation: ActivationRecomputation = ActivationRecomputation.NONE,
dtype_bytes: int = BYTES_FP32,
) -> float:
"""Get the memory (in bytes) required to store the activations of a
"""Get the memory (in bytes) required to store the activations of a
single layernorm in a transformer layer, given the batch size, sequence
length. Refer to https://arxiv.org/abs/2205.05198 for details.
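Since the hunk only shows the renamed signature and docstring, here is a sketch of the quantity it describes, following the cited paper (https://arxiv.org/abs/2205.05198): a layernorm keeps its input of shape (seq_len, batch_size, hidden_dim), sharded across sequence parallelism. This is an illustration, not the library's implementation; hidden_dim and sp_size are assumed parameters.

```python
def memory_activation_per_layernorm(
    batch_size: int,
    seq_len: int,
    hidden_dim: int,
    sp_size: int = 1,       # sequence-parallel degree (assumed parameter)
    dtype_bytes: int = 4,   # fp32 by default, matching BYTES_FP32 in the signature above
) -> float:
    """Activation memory (bytes) of one layernorm: its input tensor of shape
    (seq_len, batch_size, hidden_dim), sharded over sequence parallelism."""
    return seq_len * batch_size * hidden_dim * dtype_bytes / sp_size
```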
@@ -674,8 +673,8 @@ def get_memory_activation_per_layer(
)
)

memory_activation_per_layer_layernorm = (
self.get_memory_activation_per_layer_layernorm(
memory_activation_per_layernorm = (
self.get_memory_activation_per_layernorm(
batch_size,
seq_len,
activation_recomputation,
Expand All @@ -684,26 +683,26 @@ def get_memory_activation_per_layer(
)

if is_inference:
memory_activation_per_layer = max(memory_activation_per_layer_attn, memory_activation_per_layer_mlp, memory_activation_per_layer_layernorm)
memory_activation_per_layer = max(memory_activation_per_layer_attn, memory_activation_per_layer_mlp, memory_activation_per_layernorm)
logger.debug(
f"memory_activation_per_layer for batch_size {batch_size}:"
f" {_num_to_string(memory_activation_per_layer)}B"
f" (max(attn, mlp, layernorm): max({_num_to_string(memory_activation_per_layer_attn)}B ,"
f" {_num_to_string(memory_activation_per_layer_mlp)}B , 2 *"
f" {_num_to_string(2*memory_activation_per_layer_layernorm)}B))"
f" {_num_to_string(memory_activation_per_layernorm)}B))"
)
else:
memory_activation_per_layer = (
memory_activation_per_layer_attn
+ memory_activation_per_layer_mlp
+ 2 * memory_activation_per_layer_layernorm
+ 2 * memory_activation_per_layernorm
)
logger.debug(
f"memory_activation_per_layer for batch_size {batch_size}:"
f" {_num_to_string(memory_activation_per_layer)}B"
f" (attn + mlp + layernorm: {_num_to_string(memory_activation_per_layer_attn)}B +"
f" {_num_to_string(memory_activation_per_layer_mlp)}B + 2 *"
f" {_num_to_string(2*memory_activation_per_layer_layernorm)}B)"
f" {_num_to_string(memory_activation_per_layernorm)}B)"
)

return memory_activation_per_layer
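A compact restatement of how this hunk combines the per-component terms with the renamed per-layernorm quantity; purely illustrative, not the library's code.

```python
def activation_memory_per_layer(
    attn_bytes: float,
    mlp_bytes: float,
    layernorm_bytes: float,  # one layernorm; a transformer layer has two
    is_inference: bool,
) -> float:
    """Combine per-component activation memory as the hunk above does."""
    if is_inference:
        # only the peak component matters: activations are not retained for backward
        return max(attn_bytes, mlp_bytes, layernorm_bytes)
    # training keeps everything for the backward pass: attn + mlp + two layernorms
    return attn_bytes + mlp_bytes + 2 * layernorm_bytes
```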
@@ -1021,7 +1020,7 @@ def get_latency_fwd_per_layer_layernorm(
Returns:
float: the latency in seconds for the forward pass of a single layernorm in a transformer layer
"""
activation_memory = self.get_memory_activation_per_layer_layernorm(
activation_memory = self.get_memory_activation_per_layernorm(
batch_size,
seq_len,
)
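The call site above feeds the per-layernorm activation bytes into a latency estimate. Below is a rough sketch under the assumption that the layernorm forward pass is memory-bandwidth bound; the bandwidth and efficiency values are placeholders, not llm-analysis defaults.

```python
def latency_fwd_layernorm(
    activation_bytes: float,
    hbm_bandwidth_Bps: float = 2.0e12,    # placeholder: ~2 TB/s HBM bandwidth
    hbm_memory_efficiency: float = 0.9,   # placeholder achievable fraction
) -> float:
    """Rough forward latency for a memory-bandwidth-bound layernorm:
    bytes moved divided by effective HBM bandwidth."""
    return activation_bytes / (hbm_bandwidth_Bps * hbm_memory_efficiency)
```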
@@ -2088,6 +2087,9 @@ def training(
"activation_recomputation": ActivationRecomputation(
activation_recomputation
).name,
"layernorm_dtype_bytes": layernorm_dtype_bytes,
"mlp_activation_quant_bits": mlp_activation_quant_bits,
"mlp_recompute_gelu": mlp_recompute_gelu,
"achieved_flops": self.get_TFLOPS_per_gpu(),
"flops_efficiency": self.flops_efficiency,
"hbm_memory_efficiency": self.hbm_memory_efficiency,
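The three keys added in this hunk surface existing configuration values in the returned training summary. An illustrative fragment of such a summary follows; the values are made up for the example.

```python
# illustrative fragment of a training summary report (values made up)
summary_fragment = {
    "activation_recomputation": "NONE",
    "layernorm_dtype_bytes": 4,        # e.g. fp32 layernorm activations
    "mlp_activation_quant_bits": 8,    # e.g. int8-quantized MLP activations
    "mlp_recompute_gelu": True,        # recompute GELU inputs in the backward pass
    "achieved_flops": 312.0,           # TFLOPS per GPU, illustrative
    "flops_efficiency": 0.5,
    "hbm_memory_efficiency": 0.9,
}
```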
2 changes: 1 addition & 1 deletion tests/test_training.py
@@ -70,7 +70,7 @@ def test_training_megatron_lm_1():
== "84.82 days"
)

assert _num_to_string(summary_dict["num_params_total"]) == "174.56 G"
assert _num_to_string(summary_dict["num_params_total"]) == "174.57 G"


# megatron-lm paper https://arxiv.org/abs/2104.04473 Table 2