Commit
add memory usage for embedding and last layernorm
cli99 committed Nov 10, 2023
1 parent 0f1d957 commit f0b6c84
Showing 2 changed files with 152 additions and 65 deletions.
215 changes: 151 additions & 64 deletions llm_analysis/analysis.py
@@ -293,6 +293,9 @@ def get_num_params_per_layer_router(self) -> int:
def get_num_params_per_layer_layernorm(self) -> int:
return 2 * self.model_config.hidden_dim

def get_num_params_last_layernorm(self) -> int:
return self.model_config.hidden_dim

def get_num_params_per_layer(self) -> int:
"""Get the number of parameters in a transformer layer, including the attention
and MLP linear layers.
@@ -331,7 +334,7 @@ def get_num_params_total(self) -> int:
return (
self.model_config.num_layers * self.get_num_params_per_layer() +
self.get_num_params_embedding() +
self.get_num_params_per_layer_layernorm())
self.get_num_params_last_layernorm())

def get_num_active_params_total(self) -> int:
"""Get the total number of parameters in the model, including all the
@@ -343,7 +346,7 @@ def get_num_active_params_total(self) -> int:
return (self.model_config.num_layers *
self.get_num_active_params_per_layer() +
self.get_num_params_embedding() +
self.get_num_params_per_layer_layernorm())
self.get_num_params_last_layernorm())

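For intuition about the total-parameter formula above, here is a back-of-the-envelope sketch with hypothetical GPT-3-scale numbers (illustrative only; the real methods read hidden_dim, vocab_size, and num_layers from model_config and also count per-layer biases and layernorm weights):

hidden_dim, vocab_size, num_layers = 12288, 50257, 96

num_params_embedding = vocab_size * hidden_dim        # ~0.6e9
num_params_per_layer = 12 * hidden_dim**2             # attention + MLP weights, biases ignored
num_params_last_layernorm = hidden_dim                # the newly added term

num_params_total = (num_layers * num_params_per_layer +
                    num_params_embedding +
                    num_params_last_layernorm)
print(f"{num_params_total / 1e9:.0f}B parameters")    # ~175B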
def get_memory_weight_per_layer(
self,
@@ -389,6 +392,36 @@ def get_memory_weight_per_layer(
return memory_weight_per_layer, memory_weight_attn_per_layer, memory_weight_mlp_per_layer, memory_weight_layernorm_per_layer
return memory_weight_per_layer

def get_memory_weight_last_layernorm(self, ds_zero: DSZeRO = DSZeRO.NONE):
memory_weight_last_layernorm = self.get_num_params_last_layernorm(
) * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size
if ds_zero == DSZeRO.STAGE_3:
memory_weight_last_layernorm /= self.parallelism_config.dp_size
return memory_weight_last_layernorm

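A quick sanity check of the formula above with made-up numbers (bf16 weights, tp_size=4, dp_size=8); the final layernorm adds only a few kilobytes per GPU, but is now tracked explicitly:

BITS_PER_BYTE = 8
hidden_dim, weight_bits = 8192, 16      # hypothetical model, bf16 weights
tp_size, dp_size = 4, 8

memory_weight_last_layernorm = hidden_dim * weight_bits / BITS_PER_BYTE / tp_size   # 4096 bytes
# Under ZeRO stage 3 the weights are additionally sharded across data-parallel ranks:
memory_weight_last_layernorm /= dp_size                                             # 512 bytes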
def get_memory_optimizer_state_and_gradient_embedding(
self,
master_weights_dtype_bytes: int = BYTES_FP32,
other_op_bytes: int = None,
ds_zero: DSZeRO = DSZeRO.NONE,
) -> tuple:
if other_op_bytes is None:
op_bytes_per_params = BYTES_FP32 + 2 * BYTES_FP32 # adam optimizer
else:
op_bytes_per_params = (other_op_bytes + master_weights_dtype_bytes)

memory_optimizer_state_embedding = op_bytes_per_params * self.get_num_params_embedding(
) / self.parallelism_config.tp_size
if ds_zero >= DSZeRO.STAGE_1:
memory_optimizer_state_embedding /= self.parallelism_config.dp_size

memory_gradient_embedding = master_weights_dtype_bytes * self.get_num_params_embedding(
) / self.parallelism_config.tp_size
if ds_zero >= DSZeRO.STAGE_2:
memory_gradient_embedding /= self.parallelism_config.dp_size

return memory_optimizer_state_embedding, memory_gradient_embedding

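The embedding's optimizer-state and gradient accounting follows the same per-parameter byte counts used for the transformer layers. A hedged sketch under the default assumption (Adam with fp32 master weights and two fp32 moments) and hypothetical sizes:

BYTES_FP32 = 4
num_params_embedding = 50_000 * 4096    # hypothetical vocab_size * hidden_dim
tp_size, dp_size = 2, 16

op_bytes_per_param = BYTES_FP32 + 2 * BYTES_FP32                               # master weights + Adam m and v
memory_optimizer_state = op_bytes_per_param * num_params_embedding / tp_size
memory_gradient = BYTES_FP32 * num_params_embedding / tp_size

# ZeRO stage 1 shards optimizer states across dp ranks; stage 2 additionally shards gradients:
memory_optimizer_state /= dp_size
memory_gradient /= dp_size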
def get_memory_optimizer_state_and_gradient_per_layer(
self,
master_weights_dtype_bytes: int = BYTES_FP32,
@@ -459,20 +492,51 @@ def get_memory_optimizer_state_and_gradient_per_layer(

return memory_optimizer_state_per_layer, memory_gradient_per_layer

def get_memory_embedding(self, dtype_bytes: int = BYTES_FP32) -> float:
def get_memory_optimizer_state_and_gradient_last_layernorm(
self,
master_weights_dtype_bytes: int = BYTES_FP32,
other_op_bytes: int = None,
ds_zero: DSZeRO = DSZeRO.NONE,
) -> tuple:
if other_op_bytes is None:
op_bytes_per_params = BYTES_FP32 + 2 * BYTES_FP32 # adam optimizer
else:
op_bytes_per_params = (other_op_bytes + master_weights_dtype_bytes)

memory_optimizer_state_last_layernorm = op_bytes_per_params * self.get_num_params_last_layernorm(
) / self.parallelism_config.tp_size
if ds_zero >= DSZeRO.STAGE_1:
memory_optimizer_state_last_layernorm /= self.parallelism_config.dp_size

memory_gradient_last_layernorm = master_weights_dtype_bytes * self.get_num_params_last_layernorm(
) / self.parallelism_config.tp_size
if ds_zero >= DSZeRO.STAGE_2:
memory_gradient_last_layernorm /= self.parallelism_config.dp_size

return memory_optimizer_state_last_layernorm, memory_gradient_last_layernorm

def get_memory_embedding(
self,
ds_zero: DSZeRO = DSZeRO.NONE,
) -> float:
"""Get the memory (in bytes) required to store the embedding layer, given the
number of parameters in the embedding layer, the data type (defaults to FP32)
used for the weights, and the tensor parallelism size (Megatron-LM partitions
the embedding layer across the tensor parallel groups).
Args:
dtype_bytes (int, optional): the number of bytes in the data type for embedding weight. Defaults to BYTES_FP32.
ds_zero (DSZeRO, optional): which DeepSpeed ZeRO stage to use. Defaults to DSZeRO.NONE (disabled, no sharding).
Returns:
float: the memory (in bytes) required to store the embedding layer
"""
return (self.get_num_params_embedding() /
self.parallelism_config.tp_size) * dtype_bytes
dtype_bytes = self.dtype_config.embedding_bits / BITS_PER_BYTE
memory_embedding = (self.get_num_params_embedding() /
self.parallelism_config.tp_size) * dtype_bytes
if ds_zero == DSZeRO.STAGE_3:
memory_embedding /= self.parallelism_config.dp_size

return memory_embedding

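With the new signature the embedding weight dtype comes from dtype_config rather than a dtype_bytes argument, and the result is ZeRO-aware. A hypothetical call site (the analysis instance name is illustrative; DSZeRO and BITS_PER_BYTE are defined elsewhere in this file):

# new: dtype taken from dtype_config.embedding_bits, sharded under ZeRO stage 3
weight_memory_embedding = analysis.get_memory_embedding(ds_zero=DSZeRO.STAGE_3)

# old (removed): the caller passed the byte width explicitly
# weight_memory_embedding = analysis.get_memory_embedding(
#     dtype_bytes=analysis.dtype_config.embedding_bits / BITS_PER_BYTE)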
def get_memory_activation_per_layer_attn(
self,
@@ -746,13 +810,12 @@ def get_memory_activation_per_layer(
gated_linear_units=mlp_gated_linear_units,
))

memory_activation_per_layernorm = (
self.get_memory_activation_per_layernorm(
batch_size,
seq_len,
activation_recomputation,
layernorm_dtype_bytes,
))
memory_activation_per_layernorm = self.get_memory_activation_per_layernorm(
batch_size,
seq_len,
activation_recomputation,
layernorm_dtype_bytes,
)

if is_inference:
memory_activation_per_layer = max(memory_activation_per_layer_attn,
@@ -1406,17 +1469,17 @@ def inference(
logger.info(
"num_layers not be divisible by pp_size, taking the floor")

weight_memory_embedding_per_gpu = self.get_memory_embedding(
dtype_bytes=self.dtype_config.embedding_bits / BITS_PER_BYTE)

weight_memory_embedding_per_gpu = self.get_memory_embedding(ds_zero)
weight_memory_layers_per_gpu, weight_memory_attn_per_gpu, weight_memory_mlp_per_gpu, weight_memory_layernorm_per_gpu = [
x * self.model_config.num_layers
for x in self.get_memory_weight_per_layer(ds_zero,
return_breakdown=True)
]

memory_weight_last_layernorm = self.get_memory_weight_last_layernorm(
ds_zero)
weight_memory_per_gpu = (weight_memory_layers_per_gpu +
weight_memory_embedding_per_gpu)
weight_memory_embedding_per_gpu +
memory_weight_last_layernorm)

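To see the relative sizes of the three terms now summed into weight_memory_per_gpu, a rough estimate with hypothetical 70B-class numbers (fp16 weights, tp_size=8, no ZeRO); the final layernorm is negligible but no longer silently dropped:

num_layers, hidden_dim, vocab_size = 80, 8192, 32000
bytes_per_weight, tp_size = 2, 8        # fp16 weights

weight_memory_layers = num_layers * 12 * hidden_dim**2 * bytes_per_weight / tp_size   # ~16.1 GB
weight_memory_embedding = vocab_size * hidden_dim * bytes_per_weight / tp_size        # ~65.5 MB
weight_memory_last_layernorm = hidden_dim * bytes_per_weight / tp_size                # 2 KB
weight_memory_per_gpu = (weight_memory_layers +
                         weight_memory_embedding +
                         weight_memory_last_layernorm)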
memory_left = (self.gpu_config.mem_per_GPU_in_GB * 1024**3 -
weight_memory_per_gpu)
@@ -1432,36 +1495,47 @@ def inference(
f" {_num_to_string(memory_left)}B")

# With pipeline parallelism, each stage contains L/p layers, so the first stage must store p × L/p = L layers' worth of activations regardless of the pipeline parallel size p; the activation memory required for the input embeddings, the last layernorm, and the output layer is ignored here. Refer to https://arxiv.org/abs/2205.05198 for more details.
prefill_activation_memory_batch_size_1 = (
self.get_memory_activation_per_layer(
1,
seq_len,
is_inference=True,
layernorm_dtype_bytes=layernorm_dtype_bytes,
))
prefill_memory_activation_per_layer_batch_size_1 = self.get_memory_activation_per_layer(
1,
seq_len,
is_inference=True,
layernorm_dtype_bytes=layernorm_dtype_bytes,
)
prefill_memory_activation_embedding_output_batch_size_1 = self.get_memory_activation_embedding_output(
1, seq_len)

prefill_activation_memory_batch_size_1 = max(
prefill_memory_activation_per_layer_batch_size_1,
prefill_memory_activation_embedding_output_batch_size_1)

prefill_max_batch_size_per_gpu = int(
memory_left / prefill_activation_memory_batch_size_1)
logger.info(
f"prefill_activation_memory_batch_size_1: {_num_to_string(prefill_activation_memory_batch_size_1)}B,"
" prefill_max_batch_size_per_gpu:"
f" {prefill_max_batch_size_per_gpu}")

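The max-batch-size heuristic assumes prefill activation memory scales roughly linearly with batch size: it is measured once at batch size 1 (taking the larger of the per-layer and embedding-output activations) and divided into the memory left after weights. Illustrative, hypothetical numbers:

memory_left = 30 * 1024**3                          # bytes free after model weights
activation_memory_batch_size_1 = 1.5 * 1024**3      # max(per-layer, embedding output) at batch 1

prefill_max_batch_size_per_gpu = int(memory_left / activation_memory_batch_size_1)   # 20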
prefill_activation_memory_per_gpu = (
self.get_memory_activation_per_layer(
batch_size_per_gpu,
seq_len,
is_inference=True,
layernorm_dtype_bytes=layernorm_dtype_bytes,
))
prefill_memory_activation_per_layer = self.get_memory_activation_per_layer(
batch_size_per_gpu,
seq_len,
is_inference=True,
layernorm_dtype_bytes=layernorm_dtype_bytes,
)
prefill_memory_activation_embedding_output = self.get_memory_activation_embedding_output(
batch_size_per_gpu, seq_len)
prefill_activation_memory_per_gpu = max(
prefill_memory_activation_per_layer,
prefill_memory_activation_embedding_output)

logger.info("prefill_activation_memory_per_gpu with batch_size_per_gpu"
f" {batch_size_per_gpu}:"
f" {_num_to_string(prefill_activation_memory_per_gpu)}B")
assert memory_left > prefill_activation_memory_per_gpu, (
"activation memory is too large with batch_size_per_gpu ="
"prefill activation memory is too large with batch_size_per_gpu ="
f" {batch_size_per_gpu} to fit in GPU memory(requiring"
f" {_num_to_string(prefill_activation_memory_per_gpu)}B),"
" memory_left after fitting in model weights:"
f" {_num_to_string(memory_left)}B, max_batch_size_per_gpu:"
f" {_num_to_string(memory_left)}B, prefill_max_batch_size_per_gpu:"
f" {prefill_max_batch_size_per_gpu}")

prefill_num_flops_fwd_total = self.get_num_flops_fwd_total(
@@ -1487,25 +1561,27 @@ def inference(
f" ({batch_size_per_gpu * (seq_len+num_tokens_to_generate)}) is larger"
f" than ({round(self.get_pivot(), 3)}), which is the pivot"
" point")
kv_cache_memory_per_gpu = (self.get_memory_kv_cache_per_layer(
kv_cache_memory_per_gpu = self.get_memory_kv_cache_per_layer(
batch_size_per_gpu,
seq_len + num_tokens_to_generate,
kv_cache_dtype_bytes=kv_cache_dtype_bytes,
) * num_layers_per_gpu)
) * num_layers_per_gpu

# load and store kv cache
kv_cache_latency = (2 * kv_cache_memory_per_gpu /
(self.get_gpu_hbm_bandwidth() * 10**9))

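A rough sketch of the kv-cache numbers behind these two lines, with hypothetical 70B-class values on a single A100-80GB-class device; this ignores tensor parallelism and grouped-query attention, both of which would shrink the cache:

batch, total_seq, hidden_dim, num_layers = 4, 2048, 8192, 80
kv_bytes = 2                          # fp16 kv cache
hbm_bandwidth_GBps = 2039             # assumed HBM bandwidth

kv_cache_per_layer = 2 * batch * total_seq * hidden_dim * kv_bytes       # K and V
kv_cache_memory = kv_cache_per_layer * num_layers                        # ~21.5 GB
kv_cache_latency = 2 * kv_cache_memory / (hbm_bandwidth_GBps * 10**9)    # load + store, ~21 ms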
decode_activation_memory_per_layer = (
self.get_memory_activation_per_layer(
batch_size_per_gpu,
1,
is_inference=True,
layernorm_dtype_bytes=layernorm_dtype_bytes,
))
decode_activation_memory_per_gpu = (
decode_activation_memory_per_layer)
decode_activation_memory_per_layer = self.get_memory_activation_per_layer(
batch_size_per_gpu,
1,
is_inference=True,
layernorm_dtype_bytes=layernorm_dtype_bytes,
)
decode_memory_activation_embedding_output = self.get_memory_activation_embedding_output(
batch_size_per_gpu, 1)
decode_activation_memory_per_gpu = max(
decode_activation_memory_per_layer,
decode_memory_activation_embedding_output)

logger.info(
"kv_cache_memory_per_gpu:"
@@ -1528,17 +1604,10 @@ def inference(
" decode_max_batch_size_per_gpu:"
f" {decode_max_batch_size_per_gpu}")
else:
decode_activation_memory_batch_size_1 = (
self.get_memory_activation_per_layer(
1,
seq_len + num_tokens_to_generate,
is_inference=True,
layernorm_dtype_bytes=layernorm_dtype_bytes,
))
decode_max_batch_size_per_gpu = int(
memory_left / decode_activation_memory_batch_size_1)
memory_left / prefill_activation_memory_batch_size_1)
logger.info("decode_activation_memory_batch_size_1:"
f" {decode_activation_memory_batch_size_1},"
f" {prefill_activation_memory_batch_size_1},"
" decode_max_batch_size_per_gpu:"
f" {decode_max_batch_size_per_gpu}")

@@ -1604,8 +1673,6 @@ def inference(
"prefill_activation_memory_per_gpu":
prefill_activation_memory_per_gpu,
"prefill_max_batch_size_per_gpu": prefill_max_batch_size_per_gpu,
"prefill_activation_memory_per_gpu":
prefill_activation_memory_per_gpu,
"prefill_num_flops_fwd_total": prefill_num_flops_fwd_total,
"decode_activation_memory_per_gpu":
decode_activation_memory_per_gpu,
@@ -1714,7 +1781,7 @@ def config_batch_size_and_gradient_accumulation_steps(
# batch_size_per_gpu is not None
assert (
batch_size_per_gpu <= max_batch_size_per_gpu
), f"batch_size_per_gpu must be <= max_batch_size_per_gpu, {assert_msg}"
), f"batch_size_per_gpu {batch_size_per_gpu} must be <= max_batch_size_per_gpu {max_batch_size_per_gpu}, {assert_msg}"
if gradient_accumulation_steps is None:
gradient_accumulation_steps = 1
global_batch_size = (batch_size_per_gpu *
@@ -1835,22 +1902,30 @@ def training(
logger.info(
"num_layers not be divisible by pp_size, taking the floor")

weight_memory_embedding_per_gpu = self.get_memory_embedding(
dtype_bytes=self.dtype_config.embedding_bits / BITS_PER_BYTE)
weight_memory_embedding_per_gpu = self.get_memory_embedding(ds_zero)

weight_memory_layers_per_gpu, weight_memory_attn_per_gpu, weight_memory_mlp_per_gpu, weight_memory_layernorm_per_gpu = [
x * num_layers_per_gpu
for x in self.get_memory_weight_per_layer(ds_zero,
return_breakdown=True)
]

weight_memory_per_gpu = (weight_memory_layers_per_gpu +
weight_memory_embedding_per_gpu)
weight_memory_last_layernorm = self.get_memory_weight_last_layernorm(
ds_zero)
weight_memory_per_gpu = (weight_memory_embedding_per_gpu +
weight_memory_layers_per_gpu +
weight_memory_last_layernorm)

optimizer_state_memory_per_layer, gradient_memory_per_layer = self.get_memory_optimizer_state_and_gradient_per_layer(
master_weights_dtype_bytes, other_op_bytes, ds_zero)
optimizer_state_memory_per_gpu = optimizer_state_memory_per_layer * num_layers_per_gpu
gradient_memory_per_gpu = gradient_memory_per_layer * num_layers_per_gpu

optimizer_state_memory_embedding, gradient_memory_embedding = self.get_memory_optimizer_state_and_gradient_embedding(
master_weights_dtype_bytes, other_op_bytes, ds_zero)

optimizer_state_memory_last_layernorm, gradient_memory_last_layernorm = self.get_memory_optimizer_state_and_gradient_last_layernorm(
master_weights_dtype_bytes, other_op_bytes, ds_zero)

optimizer_state_memory_per_gpu = optimizer_state_memory_per_layer * num_layers_per_gpu + optimizer_state_memory_embedding + optimizer_state_memory_last_layernorm
gradient_memory_per_gpu = gradient_memory_per_layer * num_layers_per_gpu + gradient_memory_embedding + gradient_memory_last_layernorm

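With this change the embedding and final layernorm contribute to the per-GPU optimizer-state and gradient totals rather than only the transformer layers. A hedged sketch with hypothetical sizes (Adam defaults, tp_size=1, no ZeRO):

BYTES_FP32 = 4
num_layers_per_gpu = 24
params_per_layer = 12 * 4096**2
params_embedding = 50_000 * 4096
params_last_layernorm = 4096
params_on_gpu = (num_layers_per_gpu * params_per_layer +
                 params_embedding + params_last_layernorm)

optimizer_state_memory_per_gpu = (BYTES_FP32 + 2 * BYTES_FP32) * params_on_gpu   # ~60 GB
gradient_memory_per_gpu = BYTES_FP32 * params_on_gpu                             # ~20 GB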
memory_left = (self.gpu_config.mem_per_GPU_in_GB * 1024**3 -
weight_memory_per_gpu - optimizer_state_memory_per_gpu -
@@ -1897,6 +1972,12 @@ def training(
activation_memory_embedding_output_per_gpu = self.get_memory_activation_embedding_output(
1, seq_len)
activation_memory_batch_size_1 += activation_memory_embedding_output_per_gpu
activation_memory_batch_size_1 += self.get_memory_activation_per_layernorm(
1,
seq_len,
activation_recomputation,
layernorm_dtype_bytes,
)

max_batch_size_per_gpu = int(memory_left //
activation_memory_batch_size_1)
@@ -1942,6 +2023,12 @@ def training(
activation_memory_embedding_output_per_gpu = self.get_memory_activation_embedding_output(
batch_size_per_gpu, seq_len)
activation_memory_per_gpu += activation_memory_embedding_output_per_gpu
activation_memory_per_gpu += self.get_memory_activation_per_layernorm(
batch_size_per_gpu,
seq_len,
activation_recomputation,
layernorm_dtype_bytes,
)

logger.info("activation_memory_per_gpu with batch_size_per_gpu"
f" {batch_size_per_gpu}:"
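The per-layernorm term added to the training activation totals is small relative to the per-layer activations; e.g. assuming the final-layernorm activation is roughly batch_size * seq_len * hidden_dim * layernorm_dtype_bytes (hypothetical values):

batch_size_per_gpu, seq_len, hidden_dim, layernorm_dtype_bytes = 2, 4096, 8192, 4
activation_memory_last_layernorm = (batch_size_per_gpu * seq_len *
                                    hidden_dim * layernorm_dtype_bytes)   # 256 MiB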
2 changes: 1 addition & 1 deletion tests/test_training.py
@@ -123,7 +123,7 @@ def test_training_megatron_lm_3():
pp_size = 35
total_num_gpus = 560
dp_size = total_num_gpus // (tp_size * pp_size)
batch_size_per_gpu = 6
batch_size_per_gpu = 1

achieved_tflops = 171 # reported in the paper

