From 9fb0a2819d0511c00928adaa6c439a52ee49d813 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Mon, 13 Nov 2023 12:42:56 -0800 Subject: [PATCH] add sharded dp allgather time --- llm_analysis/analysis.py | 161 +++++++++++++++++++++++---------------- 1 file changed, 97 insertions(+), 64 deletions(-) diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py index 424da4b..c7c5ad8 100644 --- a/llm_analysis/analysis.py +++ b/llm_analysis/analysis.py @@ -95,7 +95,6 @@ def __init__( hbm_memory_efficiency: float = None, intra_node_memory_efficiency: float = INTRA_NODE_MEMORY_EFFICIENCY, inter_node_memory_efficiency: float = INTER_NODE_MEMORY_EFFICIENCY, - intra_node_alltoall_efficiency: float = INTRA_NODE_ALLTOALL_EFFICIENCY, ) -> None: """LLMAnalysis constructor. @@ -117,7 +116,6 @@ def __init__( self.dtype_config = dtype_config self.intra_node_memory_efficiency = intra_node_memory_efficiency self.inter_node_memory_efficiency = inter_node_memory_efficiency - self.intra_node_alltoall_efficiency = intra_node_alltoall_efficiency if achieved_memory_bandwidth_GBs and hbm_memory_efficiency: logger.info( @@ -172,6 +170,10 @@ def __init__( " parallelism") self.total_num_params = self.get_num_params_total() + self.total_num_params_mlp = self.get_num_params_per_layer_mlp( + ) * self.model_config.num_layers + self.total_num_params_embedding = self.get_num_params_embedding() + self.total_num_params_others = self.total_num_params - self.total_num_params_mlp - self.total_num_params_embedding self.total_num_active_params = self.get_num_active_params_total() def update_model_config(self, model_config: ModelConfig) -> None: @@ -1043,11 +1045,11 @@ def get_latency_fwd_per_layer_attn( logger.debug( "latency_fwd_per_layer_attn:" f" {round(max(compute_latency, memory_latency)*1000, 3)} ms" - " (max(compute_latency, weight_memory_latency+" - " activation_memory_latency) =" + " (max(compute_latency, weight_memory_latency +" + " activation_memory_latency):" f" max({round(compute_latency*1000, 3)}," - f" ({round(weight_memory_latency*1000, 3)} +" - f" {round(activation_memory_latency*1000, 3)})))") + f" {round(weight_memory_latency*1000, 3)} +" + f" {round(activation_memory_latency*1000, 3)}))") return max(compute_latency, memory_latency) @@ -1055,8 +1057,12 @@ def get_latency_fwd_per_layer_mlp_moe_alltoall(self, batch_size: int, seq_len: int) -> float: data_nums = self.model_config.moe_top_k * batch_size * seq_len * self.model_config.hidden_dim data_bytes = data_nums * self.dtype_config.activation_bits / BITS_PER_BYTE - return data_bytes / (self.gpu_config.intra_node_bandwidth_in_GB_per_sec - * self.intra_node_alltoall_efficiency * 10**9) + + latency = data_bytes / (self.get_intra_node_bandwidth() * 10**9) + logger.info( + f'moe_alltoall data_bytes = {_num_to_string(data_bytes)}B, latency = {round(latency*1000, 3)} ms' + ) + return latency def get_latency_fwd_per_layer_mlp( self, @@ -1109,8 +1115,11 @@ def get_latency_fwd_per_layer_mlp( " (max(compute_latency, weight_memory_latency+" " activation_memory_latency) =" f" max({round(compute_latency*1000, 3)}," - f" ({round(weight_memory_latency*1000, 3)} +" - f" {round(activation_memory_latency*1000, 3)})))") + f" {round(weight_memory_latency*1000, 3)} +" + f" {round(activation_memory_latency*1000, 3)}))") + + logger.debug( + f'alltoall_latency = {round(alltoall_latency*1000, 3)} ms') return max(compute_latency, memory_latency) + alltoall_latency @@ -1196,7 +1205,7 @@ def time_allgather(S, n, B): latency_allgather_params_mlp = time_allgather( params_bytes_mlp, dp_size / ep_size, (self.get_intra_node_bandwidth() - if ep_size <= 8 else self.get_inter_node_bandwidth()) * 10**9) + if ep_size < 8 else self.get_inter_node_bandwidth()) * 10**9) latency_allgather_params_non_mlp = time_allgather( params_bytes_non_mlp, dp_size, @@ -1205,8 +1214,8 @@ def time_allgather(S, n, B): latency_fwd_per_layer_shared_dp_comm = latency_allgather_params_mlp + latency_allgather_params_non_mlp - logger.debug( - f'params_bytes_mlp: {_num_to_string(params_bytes_mlp)}B, params_bytes_non_mlp: {_num_to_string(params_bytes_non_mlp)}B,latency_allgather_params_mlp: {round(latency_allgather_params_mlp*1000, 3)} ms, latency_allgather_params_non_mlp: {round(latency_allgather_params_non_mlp*1000, 3)} ms' + logger.info( + f'params_bytes_mlp: {_num_to_string(params_bytes_mlp)}B, params_bytes_non_mlp: {_num_to_string(params_bytes_non_mlp)}B, latency_allgather_params_mlp: {round(latency_allgather_params_mlp*1000, 3)} ms, latency_allgather_params_non_mlp: {round(latency_allgather_params_non_mlp*1000, 3)} ms' ) return latency_fwd_per_layer_shared_dp_comm @@ -1219,6 +1228,7 @@ def get_latency_fwd_per_layer( activation_recomputation: ActivationRecomputation = ActivationRecomputation.NONE, layernorm_dtype_bytes: int = BYTES_FP32, + ds_zero: DSZeRO = DSZeRO.NONE, ) -> tuple: """Get the latency for the forward pass of a transformer layer, given the batch size, sequence length, training or inference, activation recomputation strategy, @@ -1232,10 +1242,14 @@ def get_latency_fwd_per_layer( is_inference (bool, optional): whether it is inference or not. Defaults to True. activation_recomputation (ActivationRecomputation, optional): activation recomputation strategy. Defaults to ActivationRecomputation.NONE. layernorm_dtype_bytes (int, optional): number of bytes in the data type for the layernorm activations. Defaults to BYTES_FP32. Often has to be FP32 in training to maintain model accuracy. + ds_zero (DSZeRO, optional): which DeepSpeed ZeRO stage to use. Defaults to DSZeRO.NONE (disabled). Returns: tuple: a tuple of the latency in seconds for the forward pass of a transformer layer and its breakdown dict """ + if ds_zero != ds_zero.NONE: + assert not is_inference, "DeepSpeed ZeRO is only supported in training" + latency_fwd_per_layer_attn = self.get_latency_fwd_per_layer_attn( batch_size, seq_len, is_inference, activation_recomputation) @@ -1258,35 +1272,36 @@ def get_latency_fwd_per_layer( seq_len, self.dtype_config.activation_bits / BITS_PER_BYTE, ) - - latency_fwd_per_layer_shared_dp_comm = self.get_latency_fwd_per_layer_shared_dp_comm( - ) logger.debug( f"latency_fwd_per_layer_tp_comm: {round(latency_fwd_per_layer_tp_comm*1000, 3)} ms" ) - latency_per_layer = (latency_fwd_per_layer_attn + - latency_fwd_per_layer_mlp + - 2 * latency_fwd_per_layer_layernorm + - 2 * latency_fwd_per_layer_tp_comm + - latency_fwd_per_layer_shared_dp_comm) - logger.debug( - f"latency_fwd_per_layer_shared_dp_comm: {round(latency_fwd_per_layer_shared_dp_comm*1000, 3)} ms" + latency_fwd_per_layer_shared_dp_comm = self.get_latency_fwd_per_layer_shared_dp_comm( ) - logger.debug( - f"latency_per_layer: {round(latency_per_layer*1000, 3)} ms" - f" ({round(latency_fwd_per_layer_attn*1000, 3)} +" + latency_per_layer = latency_fwd_per_layer_attn + latency_fwd_per_layer_mlp + 2 * latency_fwd_per_layer_layernorm + 2 * latency_fwd_per_layer_tp_comm + + if ds_zero > DSZeRO.STAGE_1 and latency_fwd_per_layer_shared_dp_comm > latency_per_layer: + logger.warning( + f'allgather communication time to unshard model weight {round(latency_fwd_per_layer_shared_dp_comm*1000, 3)} ms is larger than compute {round(latency_per_layer*1000, 3)} ms, thus cannot be fully overlapped.' + ) + latency_per_layer = max(latency_per_layer, + latency_fwd_per_layer_shared_dp_comm) + + logger.info( + f"latency_per_layer: {round(latency_per_layer*1000, 3)} ms (max(attn + mlp + 2*layernorm + 2*tp_comm, shared_dp_comm):" + f" max({round(latency_fwd_per_layer_attn*1000, 3)} +" f" {round(latency_fwd_per_layer_mlp*1000, 3)} +" f" {round(2*latency_fwd_per_layer_layernorm*1000, 3)} +" - f" {round(2*latency_fwd_per_layer_tp_comm*1000, 3)} +" - f" {round(latency_fwd_per_layer_shared_dp_comm*1000, 3)})") + f" {round(2*latency_fwd_per_layer_tp_comm*1000, 3)}," + f" {round(latency_fwd_per_layer_shared_dp_comm*1000, 3)}))") breakdown_per_layer = { "attn": latency_fwd_per_layer_attn, "mlp": latency_fwd_per_layer_mlp, "layernorm": 2 * latency_fwd_per_layer_layernorm, "tp_comm": 2 * latency_fwd_per_layer_tp_comm, + "sharded_dp_comm": latency_fwd_per_layer_shared_dp_comm } return latency_per_layer, breakdown_per_layer @@ -1341,6 +1356,7 @@ def get_latency_fwd( ActivationRecomputation = ActivationRecomputation.NONE, layernorm_dtype_bytes: int = BYTES_FP32, breakdown_prefix: str = "", + ds_zero: DSZeRO = DSZeRO.NONE, ) -> tuple: """Get the latency for the forward pass of the transformer, given the batch size, sequence length, and whether it is inference or not, the activation @@ -1354,6 +1370,7 @@ def get_latency_fwd( activation_recomputation (ActivationRecomputation, optional): activation recomputation strategy. Defaults to ActivationRecomputation.NONE. layernorm_dtype_bytes (int, optional): number of bytes in the data type for the layernorm activations. Defaults to BYTES_FP32. Often has to be FP32 in training to maintain model accuracy. breakdown_prefix (str, optional): prefix for the breakdown dict keys. Defaults to "". + ds_zero (DSZeRO, optional): which DeepSpeed ZeRO stage to use. Defaults to DSZeRO.NONE (disabled). Returns: tuple: a tuple of the latency in seconds for the forward pass of the transformer and its breakdown dict """ @@ -1369,9 +1386,10 @@ def get_latency_fwd( is_inference, activation_recomputation, layernorm_dtype_bytes, + ds_zero, ) - latency_fwd_all_layers = latency_fwd_per_layer * num_layers_per_gpu + latency_fwd_layers = latency_fwd_per_layer * num_layers_per_gpu latency_fwd_input_embedding = self.get_latency_fwd_input_embedding( batch_size, @@ -1382,23 +1400,23 @@ def get_latency_fwd( latency_fwd_output_embedding_loss = ( self.get_latency_fwd_output_embedding_loss(batch_size, seq_len)) - total_latency = (latency_fwd_all_layers + latency_fwd_input_embedding + - latency_fwd_output_embedding_loss) + latency_fwd = (latency_fwd_layers + latency_fwd_input_embedding + + latency_fwd_output_embedding_loss) - logger.debug("latency_fwd_all_layers:" - f" {round(latency_fwd_all_layers*1000, 3)} ms" - f" ({round(latency_fwd_per_layer*1000, 3)} ms x" - f" {num_layers_per_gpu}), latency_fwd_input_embedding:" - f" {round(latency_fwd_input_embedding*1000, 3)} ms," - " latency_fwd_output_embedding_loss:" - f" {round(latency_fwd_output_embedding_loss*1000, 3)} ms") + logger.info("latency_fwd_layers:" + f" {round(latency_fwd_layers*1000, 3)} ms" + f" ({round(latency_fwd_per_layer*1000, 3)} ms x" + f" {num_layers_per_gpu}), latency_fwd_input_embedding:" + f" {round(latency_fwd_input_embedding*1000, 3)} ms," + " latency_fwd_output_embedding_loss:" + f" {round(latency_fwd_output_embedding_loss*1000, 3)} ms") - logger.debug(f"latency_fwd_total: {round(total_latency*1000, 3)} ms" - f" ({round(latency_fwd_all_layers*1000, 3)} +" - f" {round(latency_fwd_input_embedding*1000, 3)} +" - f" {round(latency_fwd_output_embedding_loss*1000, 3)})") + logger.info( + f"latency_fwd: {round(latency_fwd*1000, 3)} ms (layers + input_embedding + output_embedding_loss: " + f"{round(latency_fwd_layers*1000, 3)} + {round(latency_fwd_input_embedding*1000, 3)} + {round(latency_fwd_output_embedding_loss*1000, 3)})" + ) - total_breakdown = { + latency_fwd_breakdown = { breakdown_prefix + "latency_fwd_attn": breakdown_per_layer["attn"] * num_layers_per_gpu, breakdown_prefix + "latency_fwd_mlp": @@ -1407,12 +1425,18 @@ def get_latency_fwd( breakdown_per_layer["layernorm"] * num_layers_per_gpu, breakdown_prefix + "latency_fwd_tp_comm": breakdown_per_layer["tp_comm"] * num_layers_per_gpu, + breakdown_prefix + "latency_fwd_sharded_dp_comm": + breakdown_per_layer["sharded_dp_comm"] * num_layers_per_gpu, breakdown_prefix + "latency_fwd_input_embedding": latency_fwd_input_embedding, breakdown_prefix + "latency_fwd_output_embedding_loss": latency_fwd_output_embedding_loss, } - return total_latency, total_breakdown + return latency_fwd, latency_fwd_breakdown + + def get_latency_weight_update(self, ): + return self.weight_grad_op_state_memory_per_gpu / ( + self.get_gpu_hbm_bandwidth() * 10**9) def print_config(self, name="Training Configs") -> None: config_str = f"\n{name.center(PRINT_LINE_WIDTH, '-')}\n" @@ -1985,9 +2009,10 @@ def training( optimizer_state_memory_per_gpu = optimizer_state_memory_per_layer * num_layers_per_gpu + optimizer_state_memory_embedding + optimizer_state_memory_last_layernorm gradient_memory_per_gpu = gradient_memory_per_layer * num_layers_per_gpu + gradient_memory_embedding + gradient_memory_last_layernorm + self.weight_grad_op_state_memory_per_gpu = weight_memory_per_gpu + gradient_memory_per_gpu + optimizer_state_memory_per_gpu + memory_left = (self.gpu_config.mem_per_GPU_in_GB * 1024**3 - - weight_memory_per_gpu - optimizer_state_memory_per_gpu - - gradient_memory_per_gpu) + self.weight_grad_op_state_memory_per_gpu) logger.info( f"weight_memory_per_gpu: {_num_to_string(weight_memory_per_gpu)}B" @@ -2142,24 +2167,24 @@ def training( is_inference=False, activation_recomputation=activation_recomputation, layernorm_dtype_bytes=layernorm_dtype_bytes, + ds_zero=ds_zero, ) + # estimated by flops only: + # latency_per_micro_batch = num_flops_total_per_micro_batch / ( + # (self.parallelism_config.tp_size * self.parallelism_config.pp_size) + # * self.get_TFLOPS_per_gpu() * 1e12) + latency_per_micro_batch = latency_fwd * 3 + latency_weight_update = self.get_latency_weight_update() - mp_size = (self.parallelism_config.tp_size * - self.parallelism_config.pp_size) - - latency_per_micro_batch = num_flops_total_per_micro_batch / ( - mp_size * self.get_TFLOPS_per_gpu() * 1e12) - - latency_per_iter = (latency_per_micro_batch * - gradient_accumulation_steps) + latency_per_iter = ( + latency_per_micro_batch * gradient_accumulation_steps + + latency_weight_update) logger.info( - "latency_per_micro_batch:" - f" {round(latency_per_micro_batch * 1000, 3)} ms, latency_fwd:" - f" {round(latency_fwd * 1000, 3)} ms, \nlatency_per_iter:" - f" {round(latency_per_iter * 1000, 3)} ms" - f" ({round(latency_per_micro_batch * 1000, 3)} ms *" - f" {gradient_accumulation_steps} gradient_accumulation_steps)") + f"latency_per_micro_batch: {round(latency_per_micro_batch * 1000, 3)} ms, " + f"latency_per_iter: {round(latency_per_iter * 1000, 3)} ms " + f"({round(latency_per_micro_batch * 1000, 3)} ms latency_fwd * {gradient_accumulation_steps} gradient_accumulation_steps + {round(latency_weight_update * 1000, 3)} ms weight_update)" + ) total_num_gpus = (self.parallelism_config.tp_size * self.parallelism_config.pp_size * @@ -2230,6 +2255,12 @@ def training( total_num_tokens, "num_params_total": self.total_num_params, + "num_params_total_mlp": + self.total_num_params_mlp, + "num_params_total_embedding": + self.total_num_params_embedding, + "num_params_total_others": + self.total_num_params_others, "num_active_params_total": self.total_num_active_params, "activation_recomputation": @@ -2263,8 +2294,7 @@ def training( "optimizer_state_memory_per_gpu": optimizer_state_memory_per_gpu, "(weight+op_state+grad)_memory_per_gpu": - weight_memory_per_gpu + gradient_memory_per_gpu + - optimizer_state_memory_per_gpu, + self.weight_grad_op_state_memory_per_gpu, "activation_memory_batch_size_1": activation_memory_batch_size_1, "activation_memory_per_gpu": @@ -2278,8 +2308,8 @@ def training( "activation_memory_embedding_output_per_gpu": activation_memory_embedding_output_per_gpu, "(weight+op_state+grad+act)_memory_per_gpu": - weight_memory_per_gpu + gradient_memory_per_gpu + - optimizer_state_memory_per_gpu + activation_memory_per_gpu, + self.weight_grad_op_state_memory_per_gpu + + activation_memory_per_gpu, "memory_left_per_gpu": memory_left, "latency_per_micro_batch": @@ -2290,6 +2320,7 @@ def training( summary_dict.update(latency_fwd_breakdown) summary_dict.update({ "latency_per_iter": latency_per_iter, + "iters_per_sec": round(1 / latency_per_iter, 2), "total_training_latency": total_training_latency, "gpu_hours": gpu_hours, }) @@ -2511,6 +2542,8 @@ def train( else: dp_size = 1 + assert ep_size <= 8, "only support ep_size up to 8 GPUs per node" + model_config = get_model_config_by_name(model_name) gpu_config = get_gpu_config_by_name(gpu_name) dtype_config = get_dtype_config_by_name(dtype_name)