From f4157bea69f3df8c6cb66f2ebcda66ba03d1288e Mon Sep 17 00:00:00 2001
From: hotsuyuki
Date: Tue, 24 Dec 2024 09:46:22 -0500
Subject: [PATCH] Fix import error in `deepspeed_to_megatron.py` (#455)

Previously, `deepspeed_to_megatron.py` raised an ImportError when run as a
standalone script: the relative import (`from .deepspeed_checkpoint import
...`) only works when the file is executed as part of a package, which is not
the case when the script is invoked directly. This commit fixes the issue by
switching to an absolute import, as already done in
`deepspeed_to_transformers.py`.
---
 tools/convert_checkpoint/deepspeed_to_megatron.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tools/convert_checkpoint/deepspeed_to_megatron.py b/tools/convert_checkpoint/deepspeed_to_megatron.py
index ef1c77e546..f9116b7da0 100755
--- a/tools/convert_checkpoint/deepspeed_to_megatron.py
+++ b/tools/convert_checkpoint/deepspeed_to_megatron.py
@@ -4,7 +4,7 @@
 import os
 import torch
 from collections import OrderedDict
-from .deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint
+from deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint
 
 MODEL_KEY = 'model'
 ARGS_KEY = 'args'
@@ -92,7 +92,7 @@ def _create_rank_checkpoint(ds_checkpoint, tp_index, pp_index, for_release=False
     if pp_index == 0:
         meg_embedding_sd.update(nested_embedding_sd)
 
-    if pp_index == ds_checkpoint.pp_degree -1:
+    if pp_index == ds_checkpoint.pp_degree - 1:
         for key, value in embedding_sd.items():
             if key.startswith(WORD_EMBEDDINGS_KEY):
                 fields = key.split('.')
@@ -111,7 +111,7 @@ def _create_rank_checkpoint(ds_checkpoint, tp_index, pp_index, for_release=False
     if pp_index == 0:
         checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][EMBEDDING_KEY] = meg_embedding_sd
         checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][ENCODER_KEY] = meg_encoder_sd
-    if pp_index == ds_checkpoint.pp_degree -1:
+    if pp_index == ds_checkpoint.pp_degree - 1:
         checkpoint_sd[MODEL_KEY][WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd
 
     checkpoint_sd[ARGS_KEY] = ds_checkpoint.get_args()
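
Reviewer note (not part of the commit): a minimal sketch of why the relative
import fails and why the absolute import works, assuming the script is invoked
directly as in the patch's scenario; the invocation command is illustrative and
the script's actual CLI arguments are not shown in this patch.

    # Invoked directly, e.g. from the repository root:
    #   $ python tools/convert_checkpoint/deepspeed_to_megatron.py
    #
    # Broken: a relative import needs a parent package, but a file executed
    # as a top-level script has none, so Python raises at import time:
    #   ImportError: attempted relative import with no known parent package
    # from .deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint

    # Fixed: Python prepends the script's own directory
    # (tools/convert_checkpoint/) to sys.path, where deepspeed_checkpoint.py
    # lives, so the absolute import resolves:
    from deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint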