diff --git a/examples/multimodal/multimodal_llm/neva/eval/vqa_science.py b/examples/multimodal/multimodal_llm/neva/eval/vqa_science.py
index 91b8930a0f46..17bda5725eb4 100644
--- a/examples/multimodal/multimodal_llm/neva/eval/vqa_science.py
+++ b/examples/multimodal/multimodal_llm/neva/eval/vqa_science.py
@@ -169,7 +169,6 @@ def eval_model(args):
     parser.add_argument("--image-folder", type=str, default="")
     parser.add_argument("--question-file", type=str, default="tables/question.json")
     parser.add_argument("--answers-file", type=str, default="answer.jsonl")
-    parser.add_argument("--conv-mode", type=str, default="llava_v0")  # this flag has no use!
     parser.add_argument("--tp", type=int, default=1)
     parser.add_argument("--pp", type=int, default=1)
     parser.add_argument("--num-chunks", type=int, default=1)
diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
index 30535baae3e2..09ba2636f44e 100644
--- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
+++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
@@ -79,7 +79,7 @@
     from megatron.core import InferenceParams, dist_checkpointing, parallel_state
     from megatron.core.models.gpt import GPTModel as MCoreGPTModel
     from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
-    from megatron.core.utils import make_sharded_tensor_for_checkpoint
+    from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint

     HAVE_MEGATRON_CORE = True

@@ -248,14 +248,7 @@ def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), **kw
         state_dict = self.state_dict(prefix='', keep_vars=True)
         state_dict.pop('weight')
         # duplicate everything else
-        for layer_name in state_dict.keys():
-            tensor = state_dict[layer_name]
-            layer_key = f'{prefix}{layer_name}'
-            sharded_state_dict[layer_key] = make_sharded_tensor_for_checkpoint(
-                tensor,
-                layer_key,
-                prepend_offsets=sharded_offsets,
-            )
+        sharded_state_dict.update(make_sharded_tensors_for_checkpoint(state_dict, prefix=prefix))

         return sharded_state_dict

diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py
index c0a1a0df631d..8112071b606a 100644
--- a/nemo/utils/callbacks/dist_ckpt_io.py
+++ b/nemo/utils/callbacks/dist_ckpt_io.py
@@ -6,6 +6,8 @@
 from lightning_fabric.utilities.cloud_io import get_filesystem
 from lightning_fabric.utilities.types import _PATH
 from megatron.core import dist_checkpointing
+from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
+from megatron.core.dist_checkpointing.mapping import ShardedBase
 from megatron.core.dist_checkpointing.strategies import tensorstore

 from nemo.utils import logging
@@ -81,14 +83,36 @@ def load_checkpoint(
             sharded_strategy = None

         if not strict:
-            for key in list(sharded_state_dict['state_dict'].keys()):
-                if not os.path.isdir(f"{path}/{key}"):
-                    sharded_state_dict['state_dict'].pop(key)
+            sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)

         return dist_checkpointing.load(
             sharded_state_dict=sharded_state_dict, checkpoint_dir=path, sharded_strategy=sharded_strategy
         )

+    def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
+        ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path)
+        loaded_keys = []
+        missing_keys = []
+        unexpected_keys = []
+
+        def should_remove_missing_sharded_base(x: Any):
+            if isinstance(x, ShardedBase):
+                if x.key in ckpt_sharded_metadata:
+                    loaded_keys.append(x.key)
+                    return False
+                else:
+                    unexpected_keys.append(x.key)
+                    return True
+            return False
+
+        _, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base)
+        logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')
+
+        # TODO: compute missing_keys by:
+        #  1. all_gather_object of loaded_keys
+        #  2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys
+        return sharded_state_dict
+
     def remove_checkpoint(self, path: _PATH) -> None:
         """Remove a distributed checkpoint.
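
Note for reviewers: a minimal standalone sketch of what the new non-strict load path does, assuming megatron-core is installed. The function name drop_keys_missing_from_checkpoint and the ckpt_path argument are illustrative only and not part of the patch; the logic mirrors adjust_non_strict_load above, dropping any ShardedBase whose key is absent from the checkpoint metadata before dist_checkpointing.load runs.

    # Sketch only: mirrors the filtering done by adjust_non_strict_load above.
    # Assumes megatron-core is installed; `ckpt_path` is a hypothetical path to
    # an existing distributed checkpoint directory.
    from typing import Any, Dict

    from megatron.core import dist_checkpointing
    from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
    from megatron.core.dist_checkpointing.mapping import ShardedBase


    def drop_keys_missing_from_checkpoint(
        sharded_state_dict: Dict[str, Any], ckpt_path: str
    ) -> Dict[str, Any]:
        # Metadata lists every sharded tensor stored in the checkpoint, without loading data.
        ckpt_metadata = dist_checkpointing.load_tensors_metadata(ckpt_path)

        def is_unexpected(x: Any) -> bool:
            # True for ShardedBase objects whose key the checkpoint does not contain.
            return isinstance(x, ShardedBase) and x.key not in ckpt_metadata

        # extract_matching_values returns (matching, non-matching) values;
        # keep only entries the checkpoint actually has.
        unexpected, filtered = extract_matching_values(sharded_state_dict, is_unexpected)
        return filtered

A caller would then pass the filtered dict to dist_checkpointing.load(sharded_state_dict=filtered, checkpoint_dir=ckpt_path), which is exactly what load_checkpoint does after calling adjust_non_strict_load.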