Add NEFTune implementation for Noised Embedding Instruction Fine-Tuning support #3744

Merged
merged 10 commits on Nov 23, 2023
15 changes: 14 additions & 1 deletion ludwig/models/base.py
@@ -2,7 +2,7 @@
import logging
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -15,6 +15,7 @@
from ludwig.features.feature_registries import get_input_type_registry, get_output_type_registry
from ludwig.features.feature_utils import LudwigFeatureDict
from ludwig.modules.metric_modules import LudwigMetric
from ludwig.modules.training_hooks import TrainingHook
from ludwig.schema.features.base import BaseInputFeatureConfig, BaseOutputFeatureConfig, FeatureCollection
from ludwig.utils.algorithms_utils import topological_sort_feature_dependencies
from ludwig.utils.metric_utils import get_scalar_from_ludwig_metric
@@ -55,6 +56,9 @@ def __init__(self, random_seed: int = None):
self._eval_loss_metric = ModuleWrapper(torchmetrics.MeanMetric())
self._eval_additional_losses_metrics = ModuleWrapper(torchmetrics.MeanMetric())

# ================ Training Hook Handles ================
self._forward_hook_handles: List[TrainingHook] = []

def create_feature_dict(self) -> LudwigFeatureDict:
"""Creates and returns a LudwigFeatureDict."""
return LudwigFeatureDict()
@@ -340,6 +344,15 @@ def use_generation_config(self, generation_config: Dict[str, Any]):
raise NotImplementedError(f"{self.__class__.__name__} does not support generation_config. ")
yield

def _activate_forward_hooks(self):
"""Activates/registers forward hooks for the model."""
pass

def _deactivate_forward_hooks(self) -> None:
"""Deactivates/de-registers forward hooks for the model (if needed)."""
for handle in self._forward_hook_handles:
handle.deactivate_hook()


def create_input_feature(feature_config: BaseInputFeatureConfig, encoder_obj: Optional[Encoder]) -> InputFeature:
input_feature_cls = get_from_registry(feature_config.type, get_input_type_registry())
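
For reference, the new _forward_hook_handles list stores TrainingHook objects whose registered PyTorch hooks can later be removed via _deactivate_forward_hooks. A minimal standalone sketch of the underlying PyTorch mechanism (illustrative only, not part of this diff; the toy module and hook are assumptions):

import torch

embedding = torch.nn.Embedding(10, 4)

def double_output_hook(module, inputs, output):
    # A forward hook may return a value, which replaces the module's output.
    return output * 2.0

# register_forward_hook returns a torch.utils.hooks.RemovableHandle
handle = embedding.register_forward_hook(double_output_hook)
doubled = embedding(torch.tensor([[1, 2, 3]]))

handle.remove()  # analogous to TrainingHook.deactivate_hook() introduced later in this PR
plain = embedding(torch.tensor([[1, 2, 3]]))
assert torch.equal(doubled, plain * 2.0)
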
17 changes: 17 additions & 0 deletions ludwig/models/llm.py
@@ -14,6 +14,7 @@
from ludwig.features.text_feature import TextOutputFeature
from ludwig.globals import MODEL_WEIGHTS_FILE_NAME
from ludwig.models.base import BaseModel
from ludwig.modules.training_hooks import NEFTuneHook
from ludwig.schema.features.base import BaseOutputFeatureConfig, FeatureCollection
from ludwig.schema.model_types.llm import LLMModelConfig
from ludwig.utils.augmentation_utils import AugmentationPipelines
@@ -756,6 +757,22 @@ def _update_target_tensor_for_finetuning(

return _targets

def _activate_forward_hooks(self):
"""Activates/registers forward hooks for the model."""
if not self.config_obj.model_parameters:
return

# Initialize forward hook handles
if self.config_obj.model_parameters.neftune_noise_alpha:
self._forward_hook_handles.append(
NEFTuneHook(neftune_noise_alpha=self.config_obj.model_parameters.neftune_noise_alpha)
)

# Activate forward hooks iteratively
for hook in self._forward_hook_handles:
# Update the model with the forward hooks in place
self.model = hook.activate_hook(self.model)

@staticmethod
def get_augmentation_pipelines() -> AugmentationPipelines:
"""Returns the augmentation pipeline for this model."""
103 changes: 103 additions & 0 deletions ludwig/modules/training_hooks.py
@@ -0,0 +1,103 @@
import logging
from abc import ABC, abstractmethod

import torch

logger = logging.getLogger(__name__)


class TrainingHook(ABC):
"""A base class for training hooks in PyTorch.

This class provides a template for implementing custom training hooks
that can be activated, deactivated, and maintain a handle to the hook.

Attributes:
_hook_handle (Optional[torch.utils.hooks.RemovableHandle]): A handle to the
registered forward hook, initially set to None.
"""

def __init__(self, **kwargs) -> None:
self._hook_handle = None

@abstractmethod
def hook_fn(self, module: torch.nn.Module, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
"""Abstract method to be implemented by subclasses. This is the method that defines the custom behavior of
the training hook during a forward pass for the specified module.

Args:
module (nn.Module): The PyTorch module for which the hook is activated.
inputs (torch.Tensor): The input to the module during the forward pass.
outputs (torch.Tensor): The output from the module during the forward pass.

Returns:
torch.Tensor: The output tensor from the module.

Raises:
NotImplementedError: If the method is not implemented in a subclass.
"""
pass

def activate_hook(self, module: torch.nn.Module) -> torch.nn.Module:
"""Activates the training hook for a given module.

Args:
module (nn.Module): The PyTorch module for which the hook is activated.

Returns:
nn.Module: The input module with the training hook activated.
"""
self._hook_handle = module.register_forward_hook(self.hook_fn)
return module

def deactivate_hook(self):
"""Deactivates and removes the training hook."""
if self._hook_handle is not None:
self._hook_handle.remove()
self._hook_handle = None


class NEFTuneHook(TrainingHook):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.neftune_noise_alpha = kwargs.get("neftune_noise_alpha")

def hook_fn(self, module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
"""Implements the NEFTune forward pass for the model using forward hooks. Note that this only works for
torch.nn.Embedding layers. This method is slightly adapted from the original source code, which can be found
here: https://github.com/neelsjain/NEFTune.

The input tensor is ignored since the noise is added to the output of the embedding layer.

Returns:
torch.Tensor: The output tensor from the module.
"""
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output

def activate_hook(self, module: torch.nn.Module) -> torch.nn.Module:
"""Activates the neftune as presented in this code and paper:

Code: https://github.com/neelsjain/NEFTune
Paper: https://arxiv.org/abs/2310.05914

Args:
module (nn.Module): The PyTorch module for which the hook is activated.

Returns:
nn.Module: The input module with the training hook activated.
"""
from peft import PeftModel

if isinstance(module, PeftModel):
embeddings = module.base_model.model.get_input_embeddings()
else:
embeddings = module.get_input_embeddings()

embeddings.neftune_noise_alpha = self.neftune_noise_alpha
self._hook_handle = embeddings.register_forward_hook(self.hook_fn)

return module
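
The hook above scales the uniform noise by neftune_noise_alpha / sqrt(seq_len * hidden_dim), following the NEFTune paper and reference code. A small self-contained sketch of that computation with toy shapes (the values here are illustrative assumptions, not taken from the PR):

import torch

alpha = 5.0  # neftune_noise_alpha
output = torch.randn(2, 8, 16)  # embedding output: (batch, seq_len, hidden_dim)

dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = alpha / torch.sqrt(dims)  # 5 / sqrt(8 * 16) ≈ 0.44
noised = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)

assert noised.shape == output.shape
assert (noised - output).abs().max() <= mag_norm
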
12 changes: 12 additions & 0 deletions ludwig/schema/llms/model_parameters.py
@@ -60,10 +60,22 @@ def _jsonschema_type_mapping(self):
class ModelParametersConfig(schema_utils.BaseMarshmallowConfig):
rope_scaling: RoPEScalingConfig = RoPEScalingConfigField().get_default_field()

neftune_noise_alpha: Optional[int] = schema_utils.IntegerRange(
default=0,
min=0,
allow_none=True,
description="The alpha parameter for the embedding noise, which controls the amount of noise added to the "
"embeddings. The higher the value, the more noise is added. This is based on the paper NEFTune: Noisy "
"Embeddings Improve Instruction Finetuning. Paper: https://arxiv.org/pdf/2310.05914.pdf. Default: 0."
"Suggested values: 5, 10",
)

def to_dict(self):
config = {}
if self.rope_scaling:
config["rope_scaling"] = self.rope_scaling.to_dict()
if self.neftune_noise_alpha:
config["neftune_noise_alpha"] = self.neftune_noise_alpha
return config


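
With this schema addition, NEFTune is enabled from a Ludwig LLM config by setting model_parameters.neftune_noise_alpha, as the integration test later in this PR does. A hedged usage sketch (the base model id and feature names are placeholders, not taken from this PR):

from ludwig.api import LudwigModel

config = {
    "model_type": "llm",
    "base_model": "hf-internal-testing/tiny-random-GPTJForCausalLM",  # placeholder tiny model
    "input_features": [{"name": "prompt", "type": "text"}],
    "output_features": [{"name": "response", "type": "text"}],
    "model_parameters": {"neftune_noise_alpha": 5},  # 0 (the default) disables the noise
}

model = LudwigModel(config)
# model.train(...)  # noise is only injected while the embedding module is in training mode
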
6 changes: 6 additions & 0 deletions ludwig/trainers/trainer.py
@@ -224,6 +224,9 @@ def prepare(self):
# We may need to replace the embedding layer when using 8-bit optimizers from bitsandbytes.
update_embedding_layer(self.compiled_model, self.config)

# Register any post forward hooks for the model
self.compiled_model._activate_forward_hooks()

# Enable gradient checkpointing if configured
if self.config.enable_gradient_checkpointing:
# TODO(Arnav): Add support for gradient checkpointing in the compiled model
@@ -1017,6 +1020,9 @@ def train(
coordinator_only=False,
)

# Deactivate any forward hooks used by the model during training.
self.compiled_model._deactivate_forward_hooks()

# Stop the profiler.
if profiler:
profiler.stop()
42 changes: 42 additions & 0 deletions tests/integration_tests/test_llm.py
@@ -1065,3 +1065,45 @@ def test_local_path_loading(tmpdir):

# Check that the models are the same
assert _compare_models(model1.model, model2.model)


@pytest.mark.parametrize(
"finetuning_strategy, embedding_noise",
[
pytest.param(None, 0, id="None_without_noise"),
pytest.param(None, 5, id="None_with_noise"),
pytest.param("lora", 0, id="lora_without_noise"),
pytest.param("lora", 5, id="lora_with_noise"),
],
)
def test_llm_finetuning_with_embedding_noise(
tmpdir,
csv_filename,
finetuning_strategy,
embedding_noise,
):
train_df, prediction_df, config = _prepare_finetuning_test(csv_filename, finetuning_strategy, LOCAL_BACKEND, {})

# Add embedding noise
if embedding_noise:
config["model_parameters"] = {"neftune_noise_alpha": embedding_noise}

model = LudwigModel(config)

if embedding_noise:
assert model.config_obj.model_parameters.neftune_noise_alpha == embedding_noise

output_directory: str = str(tmpdir)
model_directory = pathlib.Path(output_directory) / "api_experiment_run" / "model"
model.train(dataset=train_df, output_directory=output_directory, skip_save_processed_input=False)

# Make sure we can load the saved model and then use it for predictions
model = LudwigModel.load(str(model_directory), backend=LOCAL_BACKEND)

base_model = LLM(ModelConfig.from_dict(config))
assert not _compare_models(base_model, model.model) # noqa F821

preds, _ = model.predict(dataset=prediction_df, output_directory=output_directory)
preds = convert_preds(preds)

assert preds
47 changes: 46 additions & 1 deletion tests/ludwig/utils/test_llm_utils.py
@@ -1,8 +1,9 @@
import pytest
import torch
from transformers import AutoConfig
from transformers import AutoConfig, AutoModelForCausalLM

from ludwig.constants import LOGITS, PREDICTIONS, PROBABILITIES
from ludwig.modules.training_hooks import NEFTuneHook
from ludwig.utils.llm_utils import (
add_left_padding,
create_attention_mask,
@@ -298,3 +299,47 @@ def test_get_realigned_target_and_prediction_tensors_for_inference(tokenizer):
assert torch.equal(updated_predictions[of_name][LOGITS][0][-1], torch.zeros(vocab_size))
assert torch.equal(updated_predictions[of_name][LOGITS][0][-2], torch.zeros(vocab_size))
assert not torch.equal(updated_predictions[of_name][LOGITS][0][-3], torch.zeros(vocab_size))


def _setup_models_for_neftune():
module_without_hook = AutoModelForCausalLM.from_pretrained(TEST_MODEL_NAME)
module_with_hook = AutoModelForCausalLM.from_pretrained(TEST_MODEL_NAME)

# Only module_with_hook should have the NEFTuneHook
neftune_hook = NEFTuneHook(neftune_noise_alpha=5)
module_with_hook = neftune_hook.activate_hook(module_with_hook)

return module_without_hook, module_with_hook


def _forward_pass_and_assert_neftune_hook(module_without_hook, module_with_hook, mode):
assert module_with_hook.get_input_embeddings()._forward_hooks
assert not module_without_hook.get_input_embeddings()._forward_hooks

if mode == "train":
module_without_hook.train()
module_with_hook.train()
elif mode == "eval":
module_without_hook.eval()
module_with_hook.eval()

input_tensor = torch.tensor([[1, 2, 3]])
output_tensor_with_noise = module_with_hook.get_input_embeddings()(input_tensor)
output_tensor_without_noise = module_without_hook.get_input_embeddings()(input_tensor)

if mode == "train":
assert not torch.equal(output_tensor_with_noise, output_tensor_without_noise)
elif mode == "eval":
assert torch.equal(output_tensor_with_noise, output_tensor_without_noise)


def test_neftune_hook_train_mode():
"""Test that the NEFTuneHook is only applied when the module is in training mode."""
module_without_hook, module_with_hook = _setup_models_for_neftune()
_forward_pass_and_assert_neftune_hook(module_without_hook, module_with_hook, mode="train")


def test_neftune_hook_eval_mode():
"""Test that the NEFTuneHook is not applied when the module is in eval mode."""
module_without_hook, module_with_hook = _setup_models_for_neftune()
_forward_pass_and_assert_neftune_hook(module_without_hook, module_with_hook, mode="eval")