fix printing
cli99 committed Oct 20, 2023
1 parent 54b51aa commit ec8843f
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions llm_analysis/analysis.py
@@ -180,6 +180,7 @@ def __init__(
)

self.total_num_params = self.get_num_params_total()
self.total_num_active_params = self.get_num_active_params_total()

def update_model_config(self, model_config: ModelConfig) -> None:
self.model_config = model_config
@@ -355,6 +356,18 @@ def get_num_params_total(self) -> int:
+ self.get_num_params_embedding() + self.get_num_params_per_layer_layernorm()
)

def get_num_active_params_total(self) -> int:
"""Get the total number of parameters in the model, including all the
transformer layers and the embedding layer.
Returns:
int: the total number of parameters in the model
"""
return (
self.model_config.num_layers * self.get_num_active_params_per_layer()
+ self.get_num_params_embedding() + self.get_num_params_per_layer_layernorm()
)

def get_memory_weight_per_layer(
self, ds_zero: DSZeRO = DSZeRO.NONE, return_breakdown: bool = False
) -> float:
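A hedged aside on the new get_num_active_params_total: for dense models the active count typically equals the total, while for MoE-style models only the experts actually routed per token are active. The sketch below is illustrative only; expert_params, num_experts and topk are assumed names, not llm_analysis internals.

def sketch_num_params_total(num_layers, expert_params, num_experts, embedding_params):
    # every expert's weights exist in memory, so all of them count toward the total
    return num_layers * expert_params * num_experts + embedding_params

def sketch_num_params_active(num_layers, expert_params, topk, embedding_params):
    # only the top-k routed experts participate in a given forward pass
    return num_layers * expert_params * topk + embedding_params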
@@ -1319,7 +1332,7 @@ def get_readable_summary_dict(
) -> str:
log_str = f"\n{title.center(PRINT_LINE_WIDTH, '-')}\n"
for key, value in summary_dict.items():
if "num_tokens" in key or "num_params" in key or "flops" in key:
if "num_tokens" in key or "num_params" in key or "num_active_params" in key or "flops" in key:
log_str += f"{key}: {_num_to_string(value)}\n"
elif "gpu_hours" == key:
log_str += f"{key}: {int(value)}\n" if value else ""
@@ -1411,15 +1424,15 @@ def inference(
"num_layers is not divisible by pp_size, taking the floor"
)

embedding_memory_per_gpu = self.get_memory_embedding(
weight_memory_embedding_per_gpu = self.get_memory_embedding(
dtype_bytes=self.dtype_config.embedding_bits / BITS_PER_BYTE
)

weight_memory_layers_per_gpu, weight_memory_attn_per_gpu, weight_memory_mlp_per_gpu, weight_memory_layernorm_per_gpu = [x * self.model_config.num_layers for x in self.get_memory_weight_per_layer(ds_zero, return_breakdown=True)]

weight_memory_per_gpu = (
weight_memory_layers_per_gpu
+ embedding_memory_per_gpu
+ weight_memory_embedding_per_gpu
)

memory_left = (
@@ -1434,7 +1447,7 @@ def inference(
logger.info(
f"weight_memory_per_gpu: {_num_to_string(weight_memory_per_gpu)}B"
" (embedding + attn + mlp + layernorm:"
f" {_num_to_string(embedding_memory_per_gpu)}B + {_num_to_string(weight_memory_attn_per_gpu)}B + {_num_to_string(weight_memory_mlp_per_gpu)}B + {_num_to_string(weight_memory_layernorm_per_gpu)}B), memory_left:"
f" {_num_to_string(weight_memory_embedding_per_gpu)}B + {_num_to_string(weight_memory_attn_per_gpu)}B + {_num_to_string(weight_memory_mlp_per_gpu)}B + {_num_to_string(weight_memory_layernorm_per_gpu)}B), memory_left:"
f" {_num_to_string(memory_left)}B"
)

@@ -1643,7 +1656,7 @@ def inference(
"kv_cache_latency": kv_cache_latency,
"kv_cache_memory_per_gpu": kv_cache_memory_per_gpu,
"weight_memory_per_gpu": weight_memory_per_gpu,
"embedding_memory_per_gpu": embedding_memory_per_gpu,
"weight_memory_embedding_per_gpu": weight_memory_embedding_per_gpu,
"prefill_activation_memory_per_gpu": prefill_activation_memory_per_gpu,
"prefill_max_batch_size_per_gpu": prefill_max_batch_size_per_gpu,
"prefill_activation_memory_per_gpu": prefill_activation_memory_per_gpu,
@@ -1867,7 +1880,7 @@ def training(
"num_layers is not divisible by pp_size, taking the floor"
)

embedding_memory_per_gpu = self.get_memory_embedding(
weight_memory_embedding_per_gpu = self.get_memory_embedding(
dtype_bytes=self.dtype_config.embedding_bits / BITS_PER_BYTE
)

@@ -1876,7 +1889,7 @@ def training(

weight_memory_per_gpu = (
weight_memory_layers_per_gpu
+ embedding_memory_per_gpu
+ weight_memory_embedding_per_gpu
)

optimizer_state_memory_per_gpu = (
@@ -1898,7 +1911,7 @@ def training(
logger.info(
f"weight_memory_per_gpu: {_num_to_string(weight_memory_per_gpu)}B"
" (embedding_memory:"
f" {_num_to_string(embedding_memory_per_gpu)}B),"
f" {_num_to_string(weight_memory_embedding_per_gpu)}B),"
" \noptimizer_state_memory_per_gpu:"
f" {_num_to_string(optimizer_state_memory_per_gpu)}B,"
" gradient_memory_per_gpu:"
@@ -1971,6 +1984,7 @@ def training(
f" {_num_to_string(memory_left)}B, max_batch_size_per_gpu ="
f" {max_batch_size_per_gpu})"
)
memory_left -= activation_memory_per_gpu

num_flops_fwd_total = self.get_num_flops_fwd_total(
batch_size_per_gpu, seq_len
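A quick worked check of the memory_left bookkeeping touched in this hunk (numbers are illustrative, not produced by the tool): starting from device memory minus weight, optimizer-state and gradient memory, the activation footprint of the chosen batch size is subtracted so the reported memory_left reflects what remains at peak.

gpu_mem = 80e9                      # e.g. an 80 GB accelerator (assumption)
weights, opt_state, grads = 14e9, 28e9, 14e9
activations = 12e9
memory_left = gpu_mem - weights - opt_state - grads
memory_left -= activations          # mirrors the subtraction added in this commit
print(memory_left)                  # 12e9 bytes left for everything else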
@@ -2102,6 +2116,7 @@ def training(
"seq_len": seq_len,
"total_num_tokens": total_num_tokens,
"num_params_total": self.total_num_params,
"num_active_params_total": self.total_num_active_params,
"activation_recomputation": ActivationRecomputation(
activation_recomputation
).name,
@@ -2113,7 +2128,7 @@ def training(
"hbm_memory_efficiency": self.hbm_memory_efficiency,
"num_flops_total_per_micro_batch": num_flops_total_per_micro_batch,
"weight_memory_per_gpu": weight_memory_per_gpu,
"embedding_memory_per_gpu": embedding_memory_per_gpu,
"weight_memory_embedding_per_gpu": weight_memory_embedding_per_gpu,
"weight_memory_attn_per_gpu": weight_memory_attn_per_gpu,
"weight_memory_mlp_per_gpu": weight_memory_mlp_per_gpu,
"weight_memory_layernorm_per_gpu": weight_memory_layernorm_per_gpu,
@@ -2124,6 +2139,7 @@ def training(
"attn_activation_memory_per_gpu": attn_activation_memory,
"mlp_activation_memory_per_gpu": mlp_activation_memory,
"layernorm_activation_memory_per_gpu": layernorm_activation_memory,
"memory_left": memory_left,
"latency_per_micro_batch": latency_per_micro_batch,
"latency_fwd": latency_fwd,
}
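Because the summary key is renamed from "embedding_memory_per_gpu" to "weight_memory_embedding_per_gpu" in both inference() and training(), downstream code reading the returned summary dict has to follow. A hypothetical, defensive read (the summary dict below is a placeholder for what those methods return):

summary = {"weight_memory_embedding_per_gpu": 2.0e9}  # placeholder for a real summary dict
embedding_mem = summary.get(
    "weight_memory_embedding_per_gpu",
    summary.get("embedding_memory_per_gpu", 0),  # pre-rename key, kept as a fallback
)
print(embedding_mem)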
