Fix to prefix tuning to fit transformers (#2096)
See #869, #1962

Fix several issues caused by changes to the cache implementation in
transformers. In particular, past_key_values for prefix tuning is now
converted to a transformers Cache instance.

---------

Co-authored-by: Raushan Turganbay <[email protected]>
BenjaminBossan and zucchini-nlp authored Oct 24, 2024
1 parent cff2a45 commit fb6108a
Showing 4 changed files with 71 additions and 8 deletions.
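
Editorial aside (not part of this commit): a minimal sketch of the conversion the commit performs, wrapping a legacy tuple-of-tuples past_key_values, as produced by prefix tuning, into a transformers DynamicCache. The tensor shapes and layer count below are illustrative placeholders, not values taken from PEFT.

import torch
from transformers import DynamicCache

# Legacy format: one (key, value) pair per layer, each tensor shaped
# (batch_size, num_heads, num_virtual_tokens, head_dim) -- illustrative sizes.
legacy_past_key_values = tuple(
    (torch.zeros(2, 4, 10, 8), torch.zeros(2, 4, 10, 8)) for _ in range(3)
)

# New format: a transformers Cache instance, which is what get_prompt() now
# returns for prefix tuning on decoder-only models.
cache = DynamicCache.from_legacy_cache(legacy_past_key_values)
print(type(cache).__name__, cache.get_seq_length())  # DynamicCache 10
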
2 changes: 2 additions & 0 deletions setup.py
@@ -36,6 +36,8 @@
    "datasets",
    "diffusers",
    "scipy",
    "protobuf",
    "sentencepiece",
]

setup(
32 changes: 27 additions & 5 deletions src/peft/peft_model.py
Expand Up @@ -34,7 +34,7 @@
from safetensors import safe_open
from safetensors.torch import save_file as safe_save_file
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import PreTrainedModel
from transformers import Cache, DynamicCache, EncoderDecoderCache, PreTrainedModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
from transformers.utils import PushToHubMixin

@@ -730,6 +730,18 @@ def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -
            if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
                post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
                past_key_values = post_process_fn(past_key_values)
            elif peft_config.num_transformer_submodules == 1:
                # Don't apply this to encoder-decoder models or to models requiring special processing.
                # local import in case users use a very old transformers version
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            elif peft_config.num_transformer_submodules == 2 and self.base_model._supports_cache_class:
                # Don't apply this to encoder-decoder models that don't support the new Cache format yet.
                # If we don't apply this, prefix tuning fails to update the cross-attention cache.
                past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
                past_key_values.cross_attention_cache = DynamicCache()
                past_key_values.is_updated = {
                    layer_idx: False for layer_idx in range(len(past_key_values.cross_attention_cache.key_cache))
                }
            return past_key_values
        else:
            if peft_config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
@@ -2066,10 +2078,20 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if peft_config.peft_type == PeftType.POLY:
            model_kwargs["task_ids"] = task_ids
        if model_kwargs.get("past_key_values", None) is None and peft_config.peft_type == PeftType.PREFIX_TUNING:
            batch_size = model_kwargs["decoder_input_ids"].shape[0]
            past_key_values = self.get_prompt(batch_size)
            model_kwargs["past_key_values"] = past_key_values
        elif peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = model_kwargs.get("past_key_values", None)
            cache_position = model_kwargs.get("cache_position", [None])
            # check prefill stage
            is_prefill_stage = (
                # old cache implementation
                (past_key_values is None)
                # new cache implementation
                or (isinstance(past_key_values, Cache) and (cache_position[0] == 0))
            )
            if is_prefill_stage:
                batch_size = model_kwargs["decoder_input_ids"].shape[0]
                new_past_key_values = self.get_prompt(batch_size)
                model_kwargs["past_key_values"] = new_past_key_values

        return model_kwargs

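Editorial aside (a sketch, not code from the repository): the prefill check above works because, with the new Cache API, generate always passes a cache object plus a cache_position tensor whose first element is 0 only on the first forward pass, so the prefix is injected exactly once. The helper is_prefill below is hypothetical and only mirrors the condition in the diff.

import torch
from transformers import Cache, DynamicCache

def is_prefill(past_key_values, cache_position):
    # Old behavior: no cache object at all before the first step.
    if past_key_values is None:
        return True
    # New behavior: a Cache instance is always present; position 0 marks the prefill step.
    return isinstance(past_key_values, Cache) and bool(cache_position[0] == 0)

print(is_prefill(None, torch.tensor([0])))            # True  (legacy path)
print(is_prefill(DynamicCache(), torch.tensor([0])))  # True  (prefill with new cache)
print(is_prefill(DynamicCache(), torch.tensor([5])))  # False (later decoding step)
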
41 changes: 40 additions & 1 deletion tests/test_decoder_models.py
@@ -11,13 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from unittest.mock import Mock, call, patch

import pytest
import torch
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import (
    AdaLoraConfig,
@@ -466,3 +474,34 @@ def test_prompt_learning_with_grouped_query_attention(self):
        x = torch.tensor([[1, 2, 3]])
        # does not raise
        model(x)

    def test_prefix_tuning_mistral(self):
        # See issue 869, 1962
        model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
        base_model = AutoModelForCausalLM.from_pretrained(model_id)
        peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
        model = get_peft_model(base_model, peft_config)

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token

        def process(samples):
            tokenized = tokenizer(samples["quote"], truncation=True, max_length=128)
            return tokenized

        data = load_dataset("ybelkada/english_quotes_copy")
        data = data.map(process, batched=True)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    num_train_epochs=1,
                    max_steps=5,
                    per_device_train_batch_size=4,
                    output_dir=tmp_dirname,
                ),
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            trainer.train()
4 changes: 2 additions & 2 deletions tests/testing_common.py
@@ -1601,8 +1601,8 @@ def get_output(model):

        output_peft = get_output(peft_model)

        # first check trivial case is not true that peft does not affect the output; for this to work, init_lora_weight
        # must be False
        # first check trivial case is not true that peft does not affect the output; for this to work, init_weight
        # must be False (if the config supports it)
        if isinstance(peft_model, StableDiffusionPipeline):
            # for SD, check that most pixels have different values
            assert (output_before != output_peft).float().mean() > 0.8
