diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py
index 5354d3d..4bc2cdf 100644
--- a/llm_analysis/analysis.py
+++ b/llm_analysis/analysis.py
@@ -376,9 +376,9 @@ def get_memory_weight_per_layer(
         memory_weight_attn_per_layer = self.get_num_params_per_layer_attn() * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size / sharded_dp_size
-        memory_weight_mlp_per_layer = (self.get_num_params_per_layer_mlp() / self.parallelism_config.ep_size + self.get_num_params_per_layer_router()) * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size
+        memory_weight_mlp_per_layer = (self.get_num_params_per_layer_mlp() / self.parallelism_config.ep_size + self.get_num_params_per_layer_router()) * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size / (sharded_dp_size / self.parallelism_config.ep_size)
-        memory_weight_layernorm_per_layer = self.get_num_params_per_layer_layernorm() * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size
+        memory_weight_layernorm_per_layer = self.get_num_params_per_layer_layernorm() * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size / sharded_dp_size
 
         memory_weight_per_layer = memory_weight_attn_per_layer + memory_weight_mlp_per_layer + memory_weight_layernorm_per_layer
@@ -580,7 +580,6 @@ def get_memory_activation_per_layer_mlp(
             seq_len * batch_size * hidden_dim / sp_size
         ) if with_dropout else 0
 
-        print(f'XXXX recompute_gelu = {recompute_gelu}')
         if self.model_config.moe_num_experts == 1:
             memory_activation_per_layer_mlp = (
                 (1 * seq_len * batch_size * hidden_dim / sp_size)
@@ -594,14 +593,14 @@ def get_memory_activation_per_layer_mlp(
 
         return memory_activation_per_layer_mlp
 
-    def get_memory_activation_per_layer_layernorm(
+    def get_memory_activation_per_layernorm(
        self,
        batch_size: int,
        seq_len: int,
        activation_recomputation: ActivationRecomputation = ActivationRecomputation.NONE,
        dtype_bytes: int = BYTES_FP32,
     ) -> float:
-        """Get the memory (in bytes) required to store the activations of a
+        """Get the memory (in bytes) required to store the activations of a
        single layernorm in a transformer layer, given the batch size, sequence
        length. Refer to https://arxiv.org/abs/2205.05198 for details.
@@ -674,8 +673,8 @@ def get_memory_activation_per_layer(
             )
         )
 
-        memory_activation_per_layer_layernorm = (
-            self.get_memory_activation_per_layer_layernorm(
+        memory_activation_per_layernorm = (
+            self.get_memory_activation_per_layernorm(
                 batch_size,
                 seq_len,
                 activation_recomputation,
@@ -684,26 +683,26 @@ def get_memory_activation_per_layer(
         )
 
         if is_inference:
-            memory_activation_per_layer = max(memory_activation_per_layer_attn, memory_activation_per_layer_mlp, memory_activation_per_layer_layernorm)
+            memory_activation_per_layer = max(memory_activation_per_layer_attn, memory_activation_per_layer_mlp, memory_activation_per_layernorm)
             logger.debug(
                 f"memory_activation_per_layer for batch_size {batch_size}:"
                 f" {_num_to_string(memory_activation_per_layer)}B"
                 f" (max(attn, mlp, layernorm): max({_num_to_string(memory_activation_per_layer_attn)}B ,"
                 f" {_num_to_string(memory_activation_per_layer_mlp)}B , 2 *"
-                f" {_num_to_string(2*memory_activation_per_layer_layernorm)}B))"
+                f" {_num_to_string(2*memory_activation_per_layernorm)}B))"
             )
         else:
             memory_activation_per_layer = (
                 memory_activation_per_layer_attn
                 + memory_activation_per_layer_mlp
-                + 2 * memory_activation_per_layer_layernorm
+                + 2 * memory_activation_per_layernorm
             )
             logger.debug(
                 f"memory_activation_per_layer for batch_size {batch_size}:"
                 f" {_num_to_string(memory_activation_per_layer)}B"
                 f" (attn + mlp + layernorm: {_num_to_string(memory_activation_per_layer_attn)}B +"
                 f" {_num_to_string(memory_activation_per_layer_mlp)}B + 2 *"
-                f" {_num_to_string(2*memory_activation_per_layer_layernorm)}B)"
+                f" {_num_to_string(memory_activation_per_layernorm)}B)"
             )
 
         return memory_activation_per_layer
@@ -1021,7 +1020,7 @@ def get_latency_fwd_per_layer_layernorm(
         Returns:
             float: the latency in seconds for the forward pass of a single layernorm in a transformer layer
         """
-        activation_memory = self.get_memory_activation_per_layer_layernorm(
+        activation_memory = self.get_memory_activation_per_layernorm(
             batch_size,
             seq_len,
         )
@@ -2088,6 +2087,9 @@ def training(
                 "activation_recomputation": ActivationRecomputation(
                     activation_recomputation
                 ).name,
+                "layernorm_dtype_bytes": layernorm_dtype_bytes,
+                "mlp_activation_quant_bits": mlp_activation_quant_bits,
+                "mlp_recompute_gelu": mlp_recompute_gelu,
                 "achieved_flops": self.get_TFLOPS_per_gpu(),
                 "flops_efficiency": self.flops_efficiency,
                 "hbm_memory_efficiency": self.hbm_memory_efficiency,
diff --git a/tests/test_training.py b/tests/test_training.py
index e2e310a..a38d406 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -70,7 +70,7 @@ def test_training_megatron_lm_1():
         == "84.82 days"
     )
 
-    assert _num_to_string(summary_dict["num_params_total"]) == "174.56 G"
+    assert _num_to_string(summary_dict["num_params_total"]) == "174.57 G"
 
     # megatron-lm paper https://arxiv.org/abs/2104.04473 Table 2
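For intuition, here is a minimal sketch of the per-layer weight-memory arithmetic that the first hunk changes. The function and parameter names below are hypothetical, not the llm_analysis API: attention and layernorm weights are divided across tp_size and the full sharded_dp_size (presumably the data-parallel degree over which weights are ZeRO-sharded), while MoE expert weights are first split across ep_size expert-parallel groups and then sharded only over the sharded_dp_size / ep_size ranks inside each group.

BITS_PER_BYTE = 8

def weight_bytes_per_layer(params_attn, params_mlp, params_router, params_ln,
                           weight_bits, tp_size, ep_size, sharded_dp_size):
    # Attention and layernorm weights: sharded over TP and the whole sharded DP group.
    attn = params_attn * weight_bits / BITS_PER_BYTE / tp_size / sharded_dp_size
    ln = params_ln * weight_bits / BITS_PER_BYTE / tp_size / sharded_dp_size
    # MoE expert weights: split across ep_size expert-parallel ranks first, then
    # sharded only over the DP ranks within one expert group (sharded_dp_size / ep_size).
    mlp = ((params_mlp / ep_size + params_router) * weight_bits / BITS_PER_BYTE
           / tp_size / (sharded_dp_size / ep_size))
    return attn + mlp + ln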
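The renamed get_memory_activation_per_layernorm feeds the per-layer aggregation touched in the later hunks. As a rough standalone sketch (again hypothetical names, not the library's API), inference keeps only the largest of the three activation blocks, while training sums them and counts the two layernorms per transformer layer:

def activation_bytes_per_layer(attn_bytes, mlp_bytes, layernorm_bytes, is_inference):
    if is_inference:
        # Only the largest activation block is assumed live at any one time.
        return max(attn_bytes, mlp_bytes, layernorm_bytes)
    # Training stores activations of attention, MLP, and both layernorms for backprop.
    return attn_bytes + mlp_bytes + 2 * layernorm_bytes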