From a7f0bc1903493888c31436efc2452ff721fa5a67 Mon Sep 17 00:00:00 2001
From: Gerald Shen <119401249+gshennvm@users.noreply.github.com>
Date: Fri, 1 Dec 2023 12:30:00 -0800
Subject: [PATCH 1/4] only enable query key scaling during fp16 (#7946)

* only enable query key scaling during fp16

Signed-off-by: Gerald Shen

* add warning

Signed-off-by: Gerald Shen

* fixup! only enable query key scaling during fp16

Signed-off-by: Gerald Shen

* remove var from Jenkins file

Signed-off-by: Gerald Shen

* fix test by setting TE var

Signed-off-by: Gerald Shen

* set to 0 if disabled

Signed-off-by: Gerald Shen

---------

Signed-off-by: Gerald Shen
---
 Jenkinsfile                                        |  3 ---
 .../models/language_modeling/megatron_gpt_model.py | 14 ++++++++++++++
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/Jenkinsfile b/Jenkinsfile
index 9f50ace7adb8..1f974333dd3a 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -9,9 +9,6 @@ pipeline {
     timeout(time: 8, unit: 'HOURS')
     disableConcurrentBuilds(abortPrevious: true)
   }
 
-  environment {
-    NVTE_APPLY_QK_LAYER_SCALING = 1
-  }
 
   stages {

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index e66708d2d2dd..5b14532016c5 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1544,6 +1544,19 @@ def build_transformer_config(self) -> TransformerConfig:
         attention_softmax_in_fp32 = False  # not currently used in NeMo unless apply_query_key_layer_scaling is True
         apply_query_key_layer_scaling = self.cfg.get('apply_query_key_layer_scaling', False)
+
+        fp16_enabled = self.trainer.precision in [16, '16', '16-mixed']
+        if apply_query_key_layer_scaling:
+            if fp16_enabled:
+                os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "1"
+            else:
+                logging.warning(
+                    "apply_query_key_layer_scaling is only enabled when using FP16, setting it to False "
+                    "and setting NVTE_APPLY_QK_LAYER_SCALING=0"
+                )
+                os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "0"
+                apply_query_key_layer_scaling = False
+
         if apply_query_key_layer_scaling:
             attention_softmax_in_fp32 = True
 
@@ -1570,6 +1583,7 @@ def build_transformer_config(self) -> TransformerConfig:
 
         # any configs that are not in the nemo model config will be added here
         config_mapping = {
+            'apply_query_key_layer_scaling': apply_query_key_layer_scaling,
             'apply_residual_connection_post_layernorm': False,  # we don't use this in NeMo
             'layernorm_zero_centered_gamma': layernorm_zero_centered_gamma,
             'add_bias_linear': add_bias_linear,
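
Background note (PATCH 1/4): Transformer Engine picks up NVTE_APPLY_QK_LAYER_SCALING from the environment, so the model config now sets the variable itself instead of relying on CI exporting it (hence the Jenkinsfile removal). Query-key layer scaling exists to keep FP16 attention logits from overflowing in deep networks; under BF16 or FP32 the extra headroom is unnecessary, which is why the flag is forced off outside FP16. A rough sketch of the mechanism in plain PyTorch, illustrative only and not the Transformer Engine kernel; the function name and tensor shapes are assumptions:

    import math
    import torch

    def attention_probs(q, k, layer_number, apply_qk_layer_scaling=True):
        # q, k: [batch, heads, seq_len, head_dim]; layer_number is 1-based.
        scale = math.sqrt(q.size(-1))
        coeff = 1.0
        if apply_qk_layer_scaling:
            # Shrink the FP16 matmul output by the layer number so the raw
            # logits stay inside FP16 range even in deep networks...
            coeff = float(layer_number)
            scale *= coeff
        logits = torch.matmul(q, k.transpose(-2, -1)) / scale
        # ...then undo the shrink inside an FP32 softmax: the probabilities
        # are mathematically unchanged, only the intermediate range differs.
        # This is also why the patch ties attention_softmax_in_fp32 to the flag.
        return torch.softmax(logits.float() * coeff, dim=-1)
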
From d1021186c7066a30a2f23a64ef540739f23e36b7 Mon Sep 17 00:00:00 2001
From: Yi Dong <43824965+yidong72@users.noreply.github.com>
Date: Fri, 1 Dec 2023 16:16:16 -0500
Subject: [PATCH 2/4] added bf16 support (#7888)

Signed-off-by: Yi Dong
---
 .../distributed_checkpoint_averaging.py | 36 ++++++++++++++-----
 1 file changed, 28 insertions(+), 8 deletions(-)

diff --git a/scripts/checkpoint_averaging/distributed_checkpoint_averaging.py b/scripts/checkpoint_averaging/distributed_checkpoint_averaging.py
index 6939cc9b36b5..89b1430198b3 100644
--- a/scripts/checkpoint_averaging/distributed_checkpoint_averaging.py
+++ b/scripts/checkpoint_averaging/distributed_checkpoint_averaging.py
@@ -39,7 +39,8 @@
 import logging
 import os
 import shutil
-
+import numpy as np
+import tensorstore  # need to import it for bf16 support
 import zarr
 
 logging.basicConfig(level=logging.INFO)
@@ -84,6 +85,7 @@ def main():
     n = len(checkpoint_paths)
     # initialize dict, will be used to store the weights that need to be averaged
     avg_weights = {}
+    chunk_info = {}
 
     logging.info(f"Averaging {n} checkpoints ... {'at steps:' + str(args.steps) if args.steps is not None else ''}")
@@ -114,21 +116,22 @@ def main():
                 if item not in avg_weights:
                     logging.info(f"Initialized average weights dict with: {item}")
-                    avg_weights[item] = zarr.open(os.path.join(full_path, item), mode='r')
+                    array = zarr.open(os.path.join(full_path, item), mode='r')
+                    avg_weights[item] = array[:]
+                    chunk_info[item] = array.chunks
                 else:
                     logging.info(f"Updated average weights dict with weight: {item}")
                     array_z = zarr.open(os.path.join(full_path, item), mode='r')
-                    sum_array = avg_weights[item][:] + array_z[:]
-                    avg_weights[item] = zarr.array(sum_array, chunks=array_z.chunks, dtype=array_z.dtype)
+                    sum_array = avg_weights[item] + array_z[:]
+                    avg_weights[item] = sum_array
 
     for k in avg_weights:
         logging.info(f"Average weights dict key : {k}, dtype : {avg_weights[k].dtype}, shape : {avg_weights[k].shape}")
         if str(avg_weights[k].dtype).startswith("int"):
             raise ValueError("Int type not supported")
         else:
-            array_z = avg_weights[k][:]
-            array_z = array_z / n
-            avg_weights[k] = zarr.array(array_z, chunks=avg_weights[k].chunks, dtype=avg_weights[k].dtype)
+            array_z = avg_weights[k] / n
+            avg_weights[k] = array_z
 
     # Save model
     if args.steps is None:
@@ -140,7 +143,24 @@ def main():
     # save avg_weights
     for k in avg_weights:
         logging.info(f"Saving {k} to {ckpt_name}")
-        zarr.save(os.path.join(ckpt_name, k), avg_weights[k])
+        input_arr = avg_weights[k]
+        chunks = chunk_info[k]
+        # create the zarr array
+        output_array = zarr.create(
+            input_arr.shape,
+            dtype=input_arr.dtype,
+            store=os.path.join(ckpt_name, k),
+            chunks=chunks,
+            compressor=None,
+            fill_value=None,
+            write_empty_chunks=True,
+        )
+        if input_arr.dtype == np.dtype('bfloat16'):
+            arr = output_array
+            arr._dtype = input_arr.dtype
+            zarray = arr.store['.zarray']
+            arr.store['.zarray'] = zarray.replace(b'<V2', b'bfloat16')
+        output_array[:] = input_arr

From: yaoyu-33
Date: Fri, 1 Dec 2023 15:05:38 -0800
Subject: [PATCH 3/4] Proposed WAR for gpt3 eval hang with PP (#7927)

Signed-off-by: yaoyu-33
Co-authored-by: Eric Harper
---
 .../nlp/modules/common/text_generation_utils.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py
index 7951ff563290..2c4c07cca346 100644
--- a/nemo/collections/nlp/modules/common/text_generation_utils.py
+++ b/nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -507,12 +507,8 @@ def synced_generate(
     if compute_logprob:
         precision = model._trainer.precision
-        if precision in [16, "16"]:
-            dtype = torch.float16
-        elif precision in ['bf16', 'bf16-mixed']:
-            dtype = torch.bfloat16
-        else:
-            dtype = torch.float32
+        dtype = torch.float32
+
         output_logits = torch.empty(
             tokens.size(0), context_length - 1, dtype=dtype, device=torch.device("cuda")
         )
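
Background note (PATCH 2/4): stock zarr has no bfloat16 codec; importing tensorstore is what registers the dtype and makes averaging bf16 checkpoints possible. The loop also stops re-wrapping every partial sum in a zarr array and accumulates plain numpy arrays instead, and the final write patches the stored '.zarray' metadata from the opaque two-byte descriptor '<V2' to 'bfloat16' so that readers recover the real type. A condensed sketch of the same flow for a single tensor, with a hypothetical helper name and arguments, assuming zarr's DirectoryStore semantics:

    import numpy as np
    import tensorstore  # noqa: F401  # registers bfloat16 support, as in the patch
    import zarr

    def average_shards(shard_paths, out_path):
        total, chunks = None, None
        for p in shard_paths:
            arr = zarr.open(p, mode='r')
            if total is None:
                total, chunks = arr[:], arr.chunks  # materialize once into memory
            else:
                total = total + arr[:]              # plain numpy addition
        avg = total / len(shard_paths)
        out = zarr.create(avg.shape, dtype=avg.dtype, store=out_path,
                          chunks=chunks, compressor=None, fill_value=None,
                          write_empty_chunks=True)
        if avg.dtype == np.dtype('bfloat16'):
            # zarr records bf16 as a raw 2-byte type ('<V2'); rewrite the
            # metadata so tensorstore-aware readers see the real dtype.
            out._dtype = avg.dtype
            out.store['.zarray'] = out.store['.zarray'].replace(b'<V2', b'bfloat16')
        out[:] = avg

Background note (PATCH 3/4): the workaround pins the output_logits buffer to FP32 on every rank instead of deriving its dtype from trainer precision. The commit calls this a proposed WAR only; presumably the precision-derived dtype could disagree with the tensors exchanged in pipeline-parallel broadcasts during generation, leaving ranks blocked, and a uniform FP32 buffer sidesteps that at the cost of a larger allocation.
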
From ae5d7e81b8e446e5650082b1700eb92dd2e7c1bd Mon Sep 17 00:00:00 2001
From: Igor Gitman
Date: Sun, 3 Dec 2023 15:00:42 -0800
Subject: [PATCH 4/4] Pass in rotary_base to mcore and from HF (#7933)

* Pass in rotary_base to mcore and from HF

Signed-off-by: Igor Gitman

* Allow changing rotary_base from the sft config file

Signed-off-by: Igor Gitman

* Update mcore in jenkins

Signed-off-by: Igor Gitman

---------

Signed-off-by: Igor Gitman
Co-authored-by: Eric Harper
---
 Jenkinsfile                                               | 4 ++--
 examples/nlp/language_modeling/tuning/megatron_gpt_sft.py | 3 +++
 .../nlp/models/language_modeling/megatron_gpt_model.py    | 1 +
 scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py | 2 ++
 4 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/Jenkinsfile b/Jenkinsfile
index 1f974333dd3a..12fafac57a67 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -72,8 +72,8 @@ pipeline {
       steps {
         sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \
             cd Megatron-LM && \
-            git checkout e122536b7645edcb7ebf099b5c92a443f7dbf8e7 && \
-            pip install -e .'
+            git checkout 973330e9c3681604703bf1eb6b5a265d1b9b9b38 && \
+            pip install .'
       }
     }

diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
index 79dd20fcf84a..b6325be40829 100644
--- a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
+++ b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
@@ -90,6 +90,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
         if cfg.model.get('seq_len_interpolation_factor', None) is not None:
             gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor
 
+        if cfg.model.get('rotary_base', None) is not None:
+            gpt_cfg.rotary_base = cfg.model.rotary_base
+
         sft_cls = MegatronGPTSFTModel
         gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 5b14532016c5..c2e39ea03a3e 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -318,6 +318,7 @@ def model_provider_func(self, pre_process, post_process):
                 position_embedding_type=self.cfg.get('position_embedding_type', 'learned_absolute'),
                 rotary_percent=self.cfg.get('rotary_percentage', 1.0),
                 seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None),
+                rotary_base=self.cfg.get('rotary_base', 10000),
             )
         else:
             assert self.cfg.get('num_query_groups', None) is None or self.cfg.get(

diff --git a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
index c281088f8c5c..d1453aeee972 100644
--- a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
+++ b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
@@ -116,6 +116,8 @@ def load_config(args, llama_config):
             nemo_config['seq_len_interpolation_factor'] = llama_config['rope_scaling']['factor']
         else:
             raise ValueError("Only linear rope scaling type is supported now")
+    if llama_config['rope_theta'] is not None:
+        nemo_config['rotary_base'] = llama_config['rope_theta']
 
     base = 128
     while llama_config['vocab_size'] % base != 0:
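
Background note (PATCH 4/4): in rotary position embeddings the base (rope_theta in Hugging Face Llama configs, 10000 by default) fixes the geometric series of per-channel rotation frequencies, and a checkpoint trained with a different base (long-context Llama variants, for instance) degrades badly if a converter silently drops the value; hence rotary_base is threaded through the SFT config, the mcore model provider, and the HF conversion script. A small illustrative helper showing what the parameter controls; a standalone sketch, not NeMo or mcore code:

    import torch

    def rope_angles(head_dim: int, seq_len: int, rotary_base: float = 10000.0):
        # Channel pair i rotates at frequency rotary_base**(-2i/head_dim);
        # a larger base slows the high-index rotations, which is how
        # long-context checkpoints stretch their usable position range.
        inv_freq = 1.0 / (rotary_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        positions = torch.arange(seq_len).float()
        return torch.outer(positions, inv_freq)  # [seq_len, head_dim // 2]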