Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…into microsoft-main-fpdt
  • Loading branch information
saforem2 committed Dec 25, 2024
2 parents 188d37b + f4157be commit 1a21057
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tools/convert_checkpoint/deepspeed_to_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import torch
from collections import OrderedDict
from .deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint
from deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint

MODEL_KEY = 'model'
ARGS_KEY = 'args'
Expand Down Expand Up @@ -92,7 +92,7 @@ def _create_rank_checkpoint(ds_checkpoint, tp_index, pp_index, for_release=False
if pp_index == 0:
meg_embedding_sd.update(nested_embedding_sd)

if pp_index == ds_checkpoint.pp_degree -1:
if pp_index == ds_checkpoint.pp_degree - 1:
for key, value in embedding_sd.items():
if key.startswith(WORD_EMBEDDINGS_KEY):
fields = key.split('.')
Expand All @@ -111,7 +111,7 @@ def _create_rank_checkpoint(ds_checkpoint, tp_index, pp_index, for_release=False
if pp_index == 0:
checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][EMBEDDING_KEY] = meg_embedding_sd
checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][ENCODER_KEY] = meg_encoder_sd
if pp_index == ds_checkpoint.pp_degree -1:
if pp_index == ds_checkpoint.pp_degree - 1:
checkpoint_sd[MODEL_KEY][WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd

checkpoint_sd[ARGS_KEY] = ds_checkpoint.get_args()
Expand Down

0 comments on commit 1a21057

Please sign in to comment.