From ec8843ff2798576bb0fd1e2ae4bf70eda9a838c2 Mon Sep 17 00:00:00 2001
From: Cheng Li
Date: Fri, 20 Oct 2023 13:44:32 -0700
Subject: [PATCH] fix printing

---
 llm_analysis/analysis.py | 34 +++++++++++++++++++++++++---------
 1 file changed, 25 insertions(+), 9 deletions(-)

diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py
index a5cc496..2262f00 100644
--- a/llm_analysis/analysis.py
+++ b/llm_analysis/analysis.py
@@ -180,6 +180,7 @@ def __init__(
         )
 
         self.total_num_params = self.get_num_params_total()
+        self.total_num_active_params = self.get_num_active_params_total()
 
     def update_model_config(self, model_config: ModelConfig) -> None:
         self.model_config = model_config
@@ -355,6 +356,18 @@ def get_num_params_total(self) -> int:
             + self.get_num_params_embedding()
             + self.get_num_params_per_layer_layernorm()
         )
 
+    def get_num_active_params_total(self) -> int:
+        """Get the total number of active parameters in the model, including
+        all the transformer layers and the embedding layer.
+
+        Returns:
+            int: the total number of active parameters in the model
+        """
+        return (
+            self.model_config.num_layers * self.get_num_active_params_per_layer()
+            + self.get_num_params_embedding() + self.get_num_params_per_layer_layernorm()
+        )
+
     def get_memory_weight_per_layer(
         self, ds_zero: DSZeRO = DSZeRO.NONE, return_breakdown: bool = False
@@ -1319,7 +1332,7 @@ def get_readable_summary_dict(
     ) -> str:
         log_str = f"\n{title.center(PRINT_LINE_WIDTH, '-')}\n"
         for key, value in summary_dict.items():
-            if "num_tokens" in key or "num_params" in key or "flops" in key:
+            if "num_tokens" in key or "num_params" in key or "num_active_params" in key or "flops" in key:
                 log_str += f"{key}: {_num_to_string(value)}\n"
             elif "gpu_hours" == key:
                 log_str += f"{key}: {int(value)}\n" if value else ""
@@ -1411,7 +1424,7 @@ def inference(
                 "num_layers not be divisible by pp_size, taking the floor"
             )
 
-        embedding_memory_per_gpu = self.get_memory_embedding(
+        weight_memory_embedding_per_gpu = self.get_memory_embedding(
             dtype_bytes=self.dtype_config.embedding_bits / BITS_PER_BYTE
         )
 
@@ -1419,7 +1432,7 @@ def inference(
 
         weight_memory_per_gpu = (
             weight_memory_layers_per_gpu
-            + embedding_memory_per_gpu
+            + weight_memory_embedding_per_gpu
         )
 
         memory_left = (
@@ -1434,7 +1447,7 @@ def inference(
         logger.info(
             f"weight_memory_per_gpu: {_num_to_string(weight_memory_per_gpu)}B"
             " (embedding + attn + mlp + layernorm:"
-            f" {_num_to_string(embedding_memory_per_gpu)}B + {_num_to_string(weight_memory_attn_per_gpu)}B + {_num_to_string(weight_memory_mlp_per_gpu)}B + {_num_to_string(weight_memory_layernorm_per_gpu)}B), memory_left:"
+            f" {_num_to_string(weight_memory_embedding_per_gpu)}B + {_num_to_string(weight_memory_attn_per_gpu)}B + {_num_to_string(weight_memory_mlp_per_gpu)}B + {_num_to_string(weight_memory_layernorm_per_gpu)}B), memory_left:"
             f" {_num_to_string(memory_left)}B"
         )
 
@@ -1643,7 +1656,7 @@ def inference(
             "kv_cache_latency": kv_cache_latency,
             "kv_cache_memory_per_gpu": kv_cache_memory_per_gpu,
             "weight_memory_per_gpu": weight_memory_per_gpu,
-            "embedding_memory_per_gpu": embedding_memory_per_gpu,
+            "weight_memory_embedding_per_gpu": weight_memory_embedding_per_gpu,
             "prefill_activation_memory_per_gpu": prefill_activation_memory_per_gpu,
             "prefill_max_batch_size_per_gpu": prefill_max_batch_size_per_gpu,
             "prefill_activation_memory_per_gpu": prefill_activation_memory_per_gpu,
@@ -1867,7 +1880,7 @@ def training(
                 "num_layers not be divisible by pp_size, taking the floor"
             )
 
-        embedding_memory_per_gpu = self.get_memory_embedding(
+        weight_memory_embedding_per_gpu = self.get_memory_embedding(
             dtype_bytes=self.dtype_config.embedding_bits / BITS_PER_BYTE
         )
 
@@ -1876,7 +1889,7 @@ def training(
 
         weight_memory_per_gpu = (
             weight_memory_layers_per_gpu
-            + embedding_memory_per_gpu
+            + weight_memory_embedding_per_gpu
         )
 
         optimizer_state_memory_per_gpu = (
@@ -1898,7 +1911,7 @@ def training(
         logger.info(
             f"weight_memory_per_gpu: {_num_to_string(weight_memory_per_gpu)}B"
             " (embedding_memory:"
-            f" {_num_to_string(embedding_memory_per_gpu)}B),"
+            f" {_num_to_string(weight_memory_embedding_per_gpu)}B),"
             " \noptimizer_state_memory_per_gpu:"
             f" {_num_to_string(optimizer_state_memory_per_gpu)}B,"
             " gradient_memory_per_gpu:"
@@ -1971,6 +1984,7 @@ def training(
             f" {_num_to_string(memory_left)}B, max_batch_size_per_gpu ="
             f" {max_batch_size_per_gpu})"
         )
+        memory_left -= activation_memory_per_gpu
 
         num_flops_fwd_total = self.get_num_flops_fwd_total(
             batch_size_per_gpu, seq_len
@@ -2102,6 +2116,7 @@ def training(
             "seq_len": seq_len,
             "total_num_tokens": total_num_tokens,
             "num_params_total": self.total_num_params,
+            "num_active_params_total": self.total_num_active_params,
             "activation_recomputation": ActivationRecomputation(
                 activation_recomputation
             ).name,
@@ -2113,7 +2128,7 @@ def training(
             "hbm_memory_efficiency": self.hbm_memory_efficiency,
             "num_flops_total_per_micro_batch": num_flops_total_per_micro_batch,
             "weight_memory_per_gpu": weight_memory_per_gpu,
-            "embedding_memory_per_gpu": embedding_memory_per_gpu,
+            "weight_memory_embedding_per_gpu": weight_memory_embedding_per_gpu,
             "weight_memory_attn_per_gpu": weight_memory_attn_per_gpu,
             "weight_memory_mlp_per_gpu": weight_memory_mlp_per_gpu,
             "weight_memory_layernorm_per_gpu": weight_memory_layernorm_per_gpu,
@@ -2124,6 +2139,7 @@ def training(
             "attn_activation_memory_per_gpu": attn_activation_memory,
             "mlp_activation_memory_per_gpu": mlp_activation_memory,
             "layernorm_activation_memory_per_gpu": layernorm_activation_memory,
+            "memory_left": memory_left,
             "latency_per_micro_batch": latency_per_micro_batch,
             "latency_fwd": latency_fwd,
         }
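
Note on the printing fix: "num_params" is not a substring of the new key
"num_active_params_total" (after "num_" comes "active_", not "params"), so the
pre-existing check in get_readable_summary_dict never matched it and the value
printed as a raw integer. Below is a minimal, standalone sketch of the patched
matching logic; the _num_to_string helper here is a simplified stand-in for the
one in llm_analysis, and the parameter counts are illustrative MoE-style
numbers, not output from the library.

def _num_to_string(num: float) -> str:
    # Simplified humanizer: 12_900_000_000 -> "12.9 G". The real helper in
    # llm_analysis may format differently; this is only for illustration.
    for threshold, suffix in ((1e12, " T"), (1e9, " G"), (1e6, " M"), (1e3, " K")):
        if abs(num) >= threshold:
            return f"{num / threshold:.4g}{suffix}"
    return str(num)

summary_dict = {
    "num_params_total": 46_700_000_000,         # all experts counted (illustrative)
    "num_active_params_total": 12_900_000_000,  # params used per token (illustrative)
}

# Why the old check missed the new key: no "num_params" substring here.
assert "num_params" not in "num_active_params_total"

# Patched condition, mirroring get_readable_summary_dict after this commit.
for key, value in summary_dict.items():
    if ("num_tokens" in key or "num_params" in key
            or "num_active_params" in key or "flops" in key):
        print(f"{key}: {_num_to_string(value)}")
    else:
        print(f"{key}: {value}")

Running this prints "num_params_total: 46.7 G" and "num_active_params_total:
12.9 G"; without the added clause the second key would take the else branch
and print un-humanized.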