Commit f0c0a35

Merge branch 'yuya/neva_llama3' into siglip_merge_llama3
HuiyingLi committed May 17, 2024
2 parents de85369 + 5e5ad9e commit f0c0a35
Showing 3 changed files with 29 additions and 13 deletions.

@@ -169,7 +169,6 @@ def eval_model(args):
parser.add_argument("--image-folder", type=str, default="")
parser.add_argument("--question-file", type=str, default="tables/question.json")
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
parser.add_argument("--conv-mode", type=str, default="llava_v0") # this flag has no use!
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--num-chunks", type=int, default=1)

@@ -79,7 +79,7 @@
 from megatron.core import InferenceParams, dist_checkpointing, parallel_state
 from megatron.core.models.gpt import GPTModel as MCoreGPTModel
 from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
-from megatron.core.utils import make_sharded_tensor_for_checkpoint
+from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint

 HAVE_MEGATRON_CORE = True

@@ -248,14 +248,7 @@ def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), **kw
         state_dict = self.state_dict(prefix='', keep_vars=True)
         state_dict.pop('weight')
         # duplicate everything else
-        for layer_name in state_dict.keys():
-            tensor = state_dict[layer_name]
-            layer_key = f'{prefix}{layer_name}'
-            sharded_state_dict[layer_key] = make_sharded_tensor_for_checkpoint(
-                tensor,
-                layer_key,
-                prepend_offsets=sharded_offsets,
-            )
+        sharded_state_dict.update(make_sharded_tensors_for_checkpoint(state_dict, prefix=prefix))
         return sharded_state_dict


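The hunk above replaces the removed per-tensor loop with a single call to Megatron-Core's make_sharded_tensors_for_checkpoint. Below is a minimal sketch of the same pattern, using only the (state_dict, prefix=...) arguments the diff itself passes; the FrozenHead module is illustrative and not part of this commit.

# Illustrative sketch only, not part of the commit: a module whose
# sharded_state_dict defers to the helper this commit switches to.
import torch

from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint


class FrozenHead(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def sharded_state_dict(self, prefix: str = ''):
        state_dict = self.state_dict(prefix='', keep_vars=True)
        # One call wraps every tensor in the plain state dict as a sharded tensor,
        # replacing a hand-written make_sharded_tensor_for_checkpoint loop.
        return make_sharded_tensors_for_checkpoint(state_dict, prefix=prefix)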
30 changes: 27 additions & 3 deletions nemo/utils/callbacks/dist_ckpt_io.py

@@ -6,6 +6,8 @@
 from lightning_fabric.utilities.cloud_io import get_filesystem
 from lightning_fabric.utilities.types import _PATH
 from megatron.core import dist_checkpointing
+from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
+from megatron.core.dist_checkpointing.mapping import ShardedBase
 from megatron.core.dist_checkpointing.strategies import tensorstore

 from nemo.utils import logging

@@ -81,14 +83,36 @@ def load_checkpoint(
         sharded_strategy = None

         if not strict:
-            for key in list(sharded_state_dict['state_dict'].keys()):
-                if not os.path.isdir(f"{path}/{key}"):
-                    sharded_state_dict['state_dict'].pop(key)
+            sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)

         return dist_checkpointing.load(
             sharded_state_dict=sharded_state_dict, checkpoint_dir=path, sharded_strategy=sharded_strategy
         )

+    def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
+        ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path)
+        loaded_keys = []
+        missing_keys = []
+        unexpected_keys = []
+
+        def should_remove_missing_sharded_base(x: Any):
+            if isinstance(x, ShardedBase):
+                if x.key in ckpt_sharded_metadata:
+                    loaded_keys.append(x.key)
+                    return False
+                else:
+                    unexpected_keys.append(x.key)
+                    return True
+            return False
+
+        _, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base)
+        logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')
+
+        # TODO: compute missing_keys by:
+        #  1. all_gather_object of loaded_keys
+        #  2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys
+        return sharded_state_dict

     def remove_checkpoint(self, path: _PATH) -> None:
         """Remove a distributed checkpoint.
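The TODO left inside adjust_non_strict_load describes how missing_keys could be derived. A minimal sketch follows, assuming torch.distributed's default process group is initialized; compute_missing_keys is a hypothetical helper, not something this commit adds.

# Hypothetical helper, not part of the commit: implements the TODO by gathering
# every rank's loaded keys and diffing them against the checkpoint metadata.
from typing import Any, Dict, List, Set

import torch.distributed as dist


def compute_missing_keys(ckpt_sharded_metadata: Dict[str, Any], loaded_keys: List[str]) -> Set[str]:
    # Each rank loads only its own shards, so gather loaded_keys from all ranks first.
    gathered: List[Any] = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, loaded_keys)
    loaded_on_some_rank = {key for rank_keys in gathered for key in rank_keys}
    # Keys present in the checkpoint metadata but loaded on no rank are missing.
    return set(ckpt_sharded_metadata.keys()) - loaded_on_some_rank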
