From a121c592c80c25d8bd85531551b571d6167b00ee Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Wed, 18 Dec 2024 23:28:02 -0500 Subject: [PATCH] Fix Optimizer & LR scheduler & Consume Samples when Resuming in PEFT (#11631) * Fix Optimizer & LR scheduler Resume * fix unit test Signed-off-by: Chen Cui * Apply isort and black reformatting Signed-off-by: cuichenx * typo Signed-off-by: Chen Cui * Fix consume samples * Fix unit tests * Apply isort and black reformatting Signed-off-by: suiyoubi --------- Signed-off-by: Chen Cui Signed-off-by: cuichenx Signed-off-by: suiyoubi Co-authored-by: Chen Cui Co-authored-by: cuichenx Co-authored-by: suiyoubi --- nemo/collections/llm/gpt/data/fine_tuning.py | 7 ++++--- nemo/lightning/pytorch/callbacks/peft.py | 7 ++++++- tests/lightning/test_data.py | 19 ++++++++++++++++--- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/nemo/collections/llm/gpt/data/fine_tuning.py b/nemo/collections/llm/gpt/data/fine_tuning.py index 0d866bb600fe..a22ed72f4656 100644 --- a/nemo/collections/llm/gpt/data/fine_tuning.py +++ b/nemo/collections/llm/gpt/data/fine_tuning.py @@ -93,6 +93,7 @@ def __init__( self.packed_sequence_size = -1 if not packed_sequence_specs else packed_sequence_specs.packed_sequence_size self.validate_batch_size_for_packed_sequence() self.dataset_kwargs = dataset_kwargs or {} + self.init_global_step = 0 def validate_batch_size_for_packed_sequence(self): """ @@ -163,9 +164,7 @@ def state_dict(self) -> Dict[str, Any]: A dictionary containing datamodule state. """ - consumed_samples = self.data_sampler.compute_consumed_samples( - self.trainer.global_step - self.data_sampler.init_global_step - ) + consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step) return {"consumed_samples": consumed_samples} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: @@ -240,6 +239,8 @@ def _create_dataset(self, path, is_test=False, **kwargs): def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader: # pylint: disable=C0115,C0116 + self.init_global_step = self.trainer.global_step + self.data_sampler.init_global_step = self.init_global_step return WrappedDataLoader( mode=mode, dataset=dataset, diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py index d138117e4599..d2e93fe9ab42 100644 --- a/nemo/lightning/pytorch/callbacks/peft.py +++ b/nemo/lightning/pytorch/callbacks/peft.py @@ -204,7 +204,12 @@ def apply_transform(self, trainer): ) trainer.strategy.load_model_state_dict(adapter_state, strict=False) if trainer.state.fn == TrainerFn.FITTING: - trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=True) + # Load optimizer + trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=False) + # Load lr scheduler + if (lr_schedulers := adapter_state.get('lr_schedulers', None)) is not None: + for config, lrs_state in zip(trainer.lr_scheduler_configs, lr_schedulers): + config.scheduler.load_state_dict(lrs_state) for cb in trainer.callbacks[::-1]: if isinstance(cb, MegatronOptimizerModule): diff --git a/tests/lightning/test_data.py b/tests/lightning/test_data.py index 2519616766f4..b848bec3dae9 100644 --- a/tests/lightning/test_data.py +++ b/tests/lightning/test_data.py @@ -15,11 +15,18 @@ from pathlib import Path from unittest.mock import MagicMock, patch +import pytest + + +@pytest.fixture +def trainer(): + return MagicMock() + @patch( 'nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTDataset.__init__', return_value=None ) -def test_finetuning_module(mock_gpt_sft_dataset) -> None: +def test_finetuning_module(mock_gpt_sft_dataset, trainer) -> None: from nemo.collections.llm.gpt.data import FineTuningDataModule dataset_root = 'random_root' @@ -30,6 +37,8 @@ def test_finetuning_module(mock_gpt_sft_dataset) -> None: global_batch_size=8, seed=1234, ) + datamodule.trainer = trainer + datamodule.setup(stage='train') datamodule.train_dataloader() mock_gpt_sft_dataset.assert_called_once() @@ -38,7 +47,7 @@ def test_finetuning_module(mock_gpt_sft_dataset) -> None: @patch( 'nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTDataset.__init__', return_value=None ) -def test_dolly_module(mock_gpt_sft_dataset) -> None: +def test_dolly_module(mock_gpt_sft_dataset, trainer) -> None: from nemo.collections.llm.gpt.data import DollyDataModule datamodule = DollyDataModule( @@ -47,6 +56,8 @@ def test_dolly_module(mock_gpt_sft_dataset) -> None: global_batch_size=8, seed=1234, ) + datamodule.trainer = trainer + datamodule.setup(stage='train') datamodule.train_dataloader() mock_gpt_sft_dataset.assert_called_once() @@ -55,7 +66,7 @@ def test_dolly_module(mock_gpt_sft_dataset) -> None: @patch( 'nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTDataset.__init__', return_value=None ) -def test_squad_module(mock_gpt_sft_dataset) -> None: +def test_squad_module(mock_gpt_sft_dataset, trainer) -> None: from nemo.collections.llm.gpt.data import SquadDataModule datamodule = SquadDataModule( @@ -64,6 +75,8 @@ def test_squad_module(mock_gpt_sft_dataset) -> None: global_batch_size=8, seed=1234, ) + datamodule.trainer = trainer + datamodule.setup(stage='train') datamodule.train_dataloader() mock_gpt_sft_dataset.assert_called_once()