Fix Optimizer & LR scheduler & Consume Samples when Resuming in PEFT (#11631)

* Fix Optimizer & LR scheduler Resume

* fix unit test

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

* typo

Signed-off-by: Chen Cui <[email protected]>

* Fix consume samples

* Fix unit tests

* Apply isort and black reformatting

Signed-off-by: suiyoubi <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Signed-off-by: suiyoubi <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: cuichenx <[email protected]>
Co-authored-by: suiyoubi <[email protected]>
4 people authored Dec 19, 2024
1 parent 093ffc4 commit a121c59
Showing 3 changed files with 26 additions and 7 deletions.

nemo/collections/llm/gpt/data/fine_tuning.py (7 changes: 4 additions & 3 deletions)
@@ -93,6 +93,7 @@ def __init__(
         self.packed_sequence_size = -1 if not packed_sequence_specs else packed_sequence_specs.packed_sequence_size
         self.validate_batch_size_for_packed_sequence()
         self.dataset_kwargs = dataset_kwargs or {}
+        self.init_global_step = 0
 
     def validate_batch_size_for_packed_sequence(self):
         """
@@ -163,9 +164,7 @@ def state_dict(self) -> Dict[str, Any]:
             A dictionary containing datamodule state.
         """
-        consumed_samples = self.data_sampler.compute_consumed_samples(
-            self.trainer.global_step - self.data_sampler.init_global_step
-        )
+        consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
         return {"consumed_samples": consumed_samples}
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -240,6 +239,8 @@ def _create_dataset(self, path, is_test=False, **kwargs):
 
     def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader:
         # pylint: disable=C0115,C0116
+        self.init_global_step = self.trainer.global_step
+        self.data_sampler.init_global_step = self.init_global_step
         return WrappedDataLoader(
             mode=mode,
             dataset=dataset,
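
The change above keeps the resume bookkeeping on the datamodule itself: _create_dataloader records the trainer's current global_step as init_global_step (and pushes it onto the data sampler), and state_dict then computes consumed samples relative to that offset. A minimal, self-contained sketch of the pattern is shown below; SimpleSampler and SimpleDataModule are hypothetical stand-ins, and compute_consumed_samples is simplified to steps times global batch size, whereas the real sampler also accounts for data parallelism and micro-batching.

from typing import Any, Dict


class SimpleSampler:
    """Hypothetical stand-in for the data sampler (simplified)."""

    def __init__(self, global_batch_size: int):
        self.global_batch_size = global_batch_size
        self.init_global_step = 0

    def compute_consumed_samples(self, steps_since_resume: int) -> int:
        # Simplified: assumes consumed samples = steps * global batch size.
        return steps_since_resume * self.global_batch_size


class SimpleDataModule:
    """Hypothetical datamodule showing the init_global_step bookkeeping."""

    def __init__(self, global_batch_size: int):
        self.data_sampler = SimpleSampler(global_batch_size)
        self.init_global_step = 0
        self.trainer = None  # attached by the training framework

    def create_dataloader(self):
        # Record the step at which this (possibly resumed) run builds its
        # dataloader, so consumed samples are counted from the resume point.
        self.init_global_step = self.trainer.global_step
        self.data_sampler.init_global_step = self.init_global_step

    def state_dict(self) -> Dict[str, Any]:
        consumed = self.data_sampler.compute_consumed_samples(
            self.trainer.global_step - self.init_global_step
        )
        return {"consumed_samples": consumed}

As in the diff, the datamodule keeps its own copy of the step at which the dataloader was created and measures consumed samples against it, rather than reading the offset back from the sampler.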

nemo/lightning/pytorch/callbacks/peft.py (7 changes: 6 additions & 1 deletion)
@@ -204,7 +204,12 @@ def apply_transform(self, trainer):
             )
             trainer.strategy.load_model_state_dict(adapter_state, strict=False)
             if trainer.state.fn == TrainerFn.FITTING:
-                trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=True)
+                # Load optimizer
+                trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=False)
+                # Load lr scheduler
+                if (lr_schedulers := adapter_state.get('lr_schedulers', None)) is not None:
+                    for config, lrs_state in zip(trainer.lr_scheduler_configs, lr_schedulers):
+                        config.scheduler.load_state_dict(lrs_state)
 
         for cb in trainer.callbacks[::-1]:
             if isinstance(cb, MegatronOptimizerModule):
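
Beyond switching selective_restore to False for the optimizer, the new code restores each learning-rate scheduler from the checkpointed adapter state using the standard state_dict/load_state_dict protocol, zipping saved states onto the live scheduler objects in order. A small standalone sketch of that restore pattern with plain torch.optim (not the NeMo/Megatron strategy API) follows; the checkpoint dict layout here is an assumption for illustration only.

import torch
from torch import nn
from torch.optim.lr_scheduler import StepLR

# Build a model, optimizer, and scheduler, then pretend we checkpointed them.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

checkpoint = {
    "optimizer": optimizer.state_dict(),
    "lr_schedulers": [scheduler.state_dict()],  # one entry per scheduler
}

# On resume: rebuild the objects, then restore their states in order,
# mirroring the zip over trainer.lr_scheduler_configs in the diff above.
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1)
new_schedulers = [StepLR(new_optimizer, step_size=10, gamma=0.5)]

new_optimizer.load_state_dict(checkpoint["optimizer"])
if (saved_states := checkpoint.get("lr_schedulers")) is not None:
    for sched, state in zip(new_schedulers, saved_states):
        sched.load_state_dict(state)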

tests/lightning/test_data.py (19 changes: 16 additions & 3 deletions)
@@ -15,11 +15,18 @@
 from pathlib import Path
 from unittest.mock import MagicMock, patch
 
+import pytest
+
+
+@pytest.fixture
+def trainer():
+    return MagicMock()
+
+
 @patch(
     'nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTDataset.__init__', return_value=None
 )
-def test_finetuning_module(mock_gpt_sft_dataset) -> None:
+def test_finetuning_module(mock_gpt_sft_dataset, trainer) -> None:
     from nemo.collections.llm.gpt.data import FineTuningDataModule
 
     dataset_root = 'random_root'
@@ -30,6 +37,8 @@ def test_finetuning_module(mock_gpt_sft_dataset) -> None:
         global_batch_size=8,
         seed=1234,
     )
+    datamodule.trainer = trainer
+    datamodule.setup(stage='train')
 
     datamodule.train_dataloader()
     mock_gpt_sft_dataset.assert_called_once()
@@ -38,7 +47,7 @@ def test_finetuning_module(mock_gpt_sft_dataset) -> None:
 @patch(
     'nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTDataset.__init__', return_value=None
 )
-def test_dolly_module(mock_gpt_sft_dataset) -> None:
+def test_dolly_module(mock_gpt_sft_dataset, trainer) -> None:
     from nemo.collections.llm.gpt.data import DollyDataModule
 
     datamodule = DollyDataModule(
@@ -47,6 +56,8 @@ def test_dolly_module(mock_gpt_sft_dataset) -> None:
         global_batch_size=8,
         seed=1234,
     )
+    datamodule.trainer = trainer
+    datamodule.setup(stage='train')
 
     datamodule.train_dataloader()
     mock_gpt_sft_dataset.assert_called_once()
@@ -55,7 +66,7 @@ def test_dolly_module(mock_gpt_sft_dataset) -> None:
 @patch(
     'nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTDataset.__init__', return_value=None
 )
-def test_squad_module(mock_gpt_sft_dataset) -> None:
+def test_squad_module(mock_gpt_sft_dataset, trainer) -> None:
     from nemo.collections.llm.gpt.data import SquadDataModule
 
     datamodule = SquadDataModule(
@@ -64,6 +75,8 @@ def test_squad_module(mock_gpt_sft_dataset) -> None:
         global_batch_size=8,
         seed=1234,
     )
+    datamodule.trainer = trainer
+    datamodule.setup(stage='train')
 
     datamodule.train_dataloader()
     mock_gpt_sft_dataset.assert_called_once()
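
The test updates lean on a common pytest pattern: a MagicMock stands in for the Lightning Trainer, so attribute lookups such as trainer.global_step succeed without constructing a real trainer, and setup() can run against it. A minimal, self-contained illustration of the same pattern (DummyModule is hypothetical, not one of the NeMo datamodules) is:

from unittest.mock import MagicMock

import pytest


class DummyModule:
    """Hypothetical module that, like the datamodules above, reads trainer state in setup()."""

    def __init__(self):
        self.trainer = None
        self.init_global_step = 0

    def setup(self, stage: str) -> None:
        # With a MagicMock trainer, global_step is simply another mock attribute.
        self.init_global_step = self.trainer.global_step


@pytest.fixture
def trainer():
    return MagicMock()


def test_dummy_module(trainer):
    module = DummyModule()
    module.trainer = trainer
    module.setup(stage='train')
    # MagicMock caches child attributes, so both lookups return the same object.
    assert module.init_global_step is trainer.global_step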
