CPT Tuner #2168

Merged: 28 commits merged on Nov 19, 2024
Changes from 24 commits

Commits (28)
92b9e1a
added CPT model to peft
tsachiblau Oct 22, 2024
e54d380
Merge branch 'huggingface:main' into main
tsachiblau Oct 22, 2024
023f071
Merge branch 'huggingface:main' into main
tsachiblau Oct 24, 2024
54cddaf
Merge branch 'huggingface:main' into main
tsachiblau Oct 25, 2024
2dfe70f
Added arXiv link to the paper, integrated CPT into testing framework,…
tsachiblau Oct 25, 2024
ba4b115
Merge branch 'huggingface:main' into main
tsachiblau Oct 25, 2024
f8c8317
Merge branch 'huggingface:main' into main
tsachiblau Oct 30, 2024
bd2fc70
config: Added config check in __post_init__. Removed redundant initia…
tsachiblau Oct 30, 2024
b01b214
Merge branch 'main' of https://github.com/tsachiblau/peft_CPT
tsachiblau Oct 30, 2024
6ed1723
Merge branch 'huggingface:main' into main
tsachiblau Nov 3, 2024
77bb0b9
tests: Updated test_cpt and testing_common as per the PR requirements.
tsachiblau Nov 3, 2024
dbcdedf
Created cpt.md in package_regerence. Updated the prompting.md file. a…
tsachiblau Nov 3, 2024
f7138d4
Merge branch 'huggingface:main' into main
tsachiblau Nov 5, 2024
0a5fb20
verifying that the model is causal LM
tsachiblau Nov 5, 2024
7206db5
Changed CPTModel to CPTEmbedding
tsachiblau Nov 5, 2024
24b0af9
merge with main branch
tsachiblau Nov 5, 2024
81ffa09
make style
tsachiblau Nov 7, 2024
130ec76
make style
tsachiblau Nov 7, 2024
70067d8
make style
tsachiblau Nov 7, 2024
9397314
make doc
tsachiblau Nov 8, 2024
249713c
Merge branch 'huggingface:main' into main
tsachiblau Nov 10, 2024
0a43473
Removed redundant checks
tsachiblau Nov 10, 2024
144f042
Fixed errors
tsachiblau Nov 13, 2024
97449da
merge with peft
tsachiblau Nov 13, 2024
dacb400
Minor code updates.
tsachiblau Nov 13, 2024
cc348a4
Minor code updates.
tsachiblau Nov 17, 2024
79959d1
Merge branch 'huggingface:main' into main
tsachiblau Nov 18, 2024
7eea892
Minor code updates.
tsachiblau Nov 18, 2024
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -118,6 +118,8 @@
title: VB-LoRA
- local: package_reference/hra
title: HRA
- local: package_reference/cpt
title: CPT
- local: package_reference/bone
title: Bone

16 changes: 16 additions & 0 deletions docs/source/conceptual_guides/prompting.md
@@ -75,3 +75,19 @@ Take a look at [P-tuning for sequence classification](../task_guides/ptuning-seq
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/peft/mpt-decomposition.png"/>
</div>
<small><a href="https://hf.co/papers/2103.10385">Prompt decomposition</a>.</small>


## Context-Aware Prompt Tuning (CPT)

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/peft/cpt.png"/>
</div>
<small>CPT optimizing only specific token embeddings while keeping the rest of the model frozen <a href="https://huggingface.co/papers/2410.17222">(image source)</a>.</small>

[Context-Aware Prompt Tuning (CPT)](https://huggingface.co/papers/2410.17222) is designed to enhance few-shot classification by refining only context embeddings.
This approach combines ideas from In-Context Learning (ICL), Prompt Tuning (PT), and adversarial optimization, focusing on making model adaptation both parameter-efficient and effective.
In CPT, only specific context token embeddings are optimized, while the rest of the model remains frozen.
To prevent overfitting and maintain stability, CPT uses controlled perturbations to limit the allowed changes to context embeddings within a defined range.
Additionally, to address the phenomenon of recency bias—where examples near the end of the context tend to be prioritized over earlier ones—CPT applies a decay loss factor.

Take a look at [Context-Aware Prompt Tuning for few-shot classification](../task_guides/cpt-few-shot-classification) for a step-by-step guide on how to train a model with CPT.
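
As a reader aid (not part of the diff), here is a minimal sketch of how the ideas above map onto the CPTConfig added by this PR. The model name, template text, and hyperparameter values are illustrative assumptions; the field names follow src/peft/tuners/cpt/config.py further down in this PR.

```python
# Sketch only: TEXT initialization of CPT, assuming a generic causal LM.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import CPTConfig, get_peft_model

model_name = "bigscience/bloomz-560m"  # placeholder; any causal LM should work
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# The context template whose token embeddings CPT will optimize.
template = "Classify the sentiment of the review:"
cpt_token_ids = tokenizer(template, add_special_tokens=False)["input_ids"]

config = CPTConfig(
    cpt_token_ids=cpt_token_ids,
    cpt_mask=[1] * len(cpt_token_ids),
    cpt_tokens_type_mask=[1] * len(cpt_token_ids),
    num_virtual_tokens=len(cpt_token_ids),
    opt_weighted_loss_type="decay",   # decay loss factor for recency bias (see above)
    opt_loss_decay_factor=0.95,       # illustrative value
    opt_projection_epsilon=0.1,       # bounds the perturbation of context embeddings
    opt_projection_format_epsilon=0.1,
    tokenizer_name_or_path=model_name,
)

peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
```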
31 changes: 31 additions & 0 deletions docs/source/package_reference/cpt.md
@@ -0,0 +1,31 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Context-Aware Prompt Tuning (CPT)

[Context-aware Prompt Tuning: Advancing In-Context Learning with Adversarial Methods (CPT)](https://huggingface.co/papers/2410.17222) combines In-Context Learning (ICL) with Prompt Tuning (PT) and adversarial optimization to improve few-shot learning by refining context embeddings. CPT optimizes only context tokens, which minimizes overfitting and enhances performance on classification tasks.

The abstract from the paper is:

*Traditional fine-tuning is effective but computationally intensive, as it requires updating billions of parameters. CPT, inspired by ICL, PT, and adversarial attacks, refines context embeddings in a parameter-efficient manner. By optimizing context tokens and applying a controlled gradient descent, CPT achieves superior accuracy across various few-shot classification tasks, showing significant improvement over existing methods such as LoRA, PT, and ICL.*

## CPTConfig

[[autodoc]] tuners.cpt.config.CPTConfig

## CPTEmbedding

[[autodoc]] tuners.cpt.model.CPTEmbedding
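
For orientation (not part of the diff), a short sketch of constructing the config documented above; the token ids are placeholders, and CPTEmbedding itself is normally created internally by PeftModel rather than by the user.

```python
# Sketch only: how CPTConfig fills in defaults, per __post_init__ in this PR.
from peft import CPTConfig, PeftType

# TEXT init (the default): omitted masks are derived from cpt_token_ids.
config = CPTConfig(cpt_token_ids=[0, 1, 2], num_virtual_tokens=3)
assert config.peft_type == PeftType.CPT
assert config.cpt_mask == [1, 1, 1]
assert config.cpt_tokens_type_mask == [1, 1, 1]

# RANDOM init: no token ids, just a positive number of virtual tokens.
random_config = CPTConfig(cpt_prompt_init="RANDOM", num_virtual_tokens=10)
```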

2 changes: 2 additions & 0 deletions src/peft/__init__.py
@@ -94,6 +94,8 @@
VBLoRAConfig,
get_eva_state_dict,
initialize_lora_eva_weights,
CPTEmbedding,
CPTConfig,
BoneConfig,
BoneModel,
)
4 changes: 4 additions & 0 deletions src/peft/mapping.py
@@ -40,6 +40,8 @@
BOFTModel,
BoneConfig,
BoneModel,
CPTConfig,
CPTEmbedding,
FourierFTConfig,
FourierFTModel,
HRAConfig,
@@ -106,6 +108,7 @@
"XLORA": XLoraConfig,
"HRA": HRAConfig,
"VBLORA": VBLoRAConfig,
"CPT": CPTConfig,
"BONE": BoneConfig,
}

@@ -124,6 +127,7 @@
"XLORA": XLoraModel,
"HRA": HRAModel,
"VBLORA": VBLoRAModel,
"CPT": CPTEmbedding,
"BONE": BoneModel,
}

68 changes: 68 additions & 0 deletions src/peft/peft_model.py
@@ -47,6 +47,7 @@
AdaptionPromptModel,
BOFTModel,
BoneModel,
CPTEmbedding,
FourierFTModel,
HRAModel,
IA3Model,
@@ -105,6 +106,7 @@
PeftType.XLORA: XLoraModel,
PeftType.HRA: HRAModel,
PeftType.VBLORA: VBLoRAModel,
PeftType.CPT: CPTEmbedding,
PeftType.BONE: BoneModel,
}

@@ -654,6 +656,8 @@ def _setup_prompt_encoder(self, adapter_name: str):
if any(getattr(module, "gradient_checkpointing", False) for module in self.get_base_model().modules()):
raise ValueError("Prefix tuning does not work with gradient checkpointing.")
prompt_encoder = PrefixEncoder(config)
elif config.peft_type == PeftType.CPT:
prompt_encoder = CPTEmbedding(config, self.word_embeddings)
BenjaminBossan marked this conversation as resolved.
else:
raise ValueError("Not supported")

@@ -1746,6 +1750,8 @@ def forward(
# overwrite past_kv in kwargs
kwargs["past_key_values"] = self.get_prompt(batch_size)
return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
elif peft_config.peft_type == PeftType.CPT:
BenjaminBossan marked this conversation as resolved.
return self._cpt_forward(input_ids, inputs_embeds, peft_config, task_ids, batch_size, **kwargs)
else:
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
@@ -1758,6 +1764,68 @@ def forward(
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

def _cpt_forward(
self, input_ids=None, inputs_embeds=None, peft_config=None, task_ids=None, batch_size=None, **kwargs
):
# Extract labels from kwargs
labels = kwargs.pop("labels")
device = [i.device for i in [input_ids, inputs_embeds, labels] if i is not None][0]
# Extract input_type_mask from kwargs and move it to the same device as labels
if "input_type_mask" in kwargs.keys():
input_type_mask = kwargs.pop("input_type_mask").to(device)
else:
if input_ids is None:
N_tokens = inputs_embeds.shape[1]
else:
N_tokens = input_ids.shape[1]
input_type_mask = torch.zeros((batch_size, N_tokens)).to(device)
input_type_mask[:, :] = 4
Member: input_type_mask.fill_(4) would also work. Could you add a short comment on what "4" means here?

Contributor Author: 4 is the id for the tokens used for the loss calculation. I changed the code to input_type_mask = torch.ones((batch_size, N_tokens)).to(device) * 4


if peft_config.cpt_prompt_init == "TEXT":
cpt_token_ids = peft_config.cpt_token_ids
cpt_tokens_type_mask = peft_config.cpt_tokens_type_mask
else:
cpt_token_ids = [0] * peft_config.num_virtual_tokens
cpt_tokens_type_mask = [0] * peft_config.num_virtual_tokens

# Generate embeddings if not provided
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
# Get prompt and concatenate with input embeddings
prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
prompts = prompts.to(inputs_embeds.dtype)
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
# If labels are provided, generate prefix labels and type mask
cpt_labels = None
if labels is not None:
# Generate prefix labels and concatenate with the input labels
prefix_labels = torch.Tensor(cpt_token_ids).long().view(1, -1)
prefix_labels = prefix_labels.repeat(batch_size, 1).to(labels.device)
cpt_labels = torch.cat((prefix_labels, labels), dim=1)
# Generate prefix type mask and shift input type mask values to avoid conflicts
prefix_type_mask = torch.Tensor(cpt_tokens_type_mask).long().view(1, -1)
prefix_type_mask = prefix_type_mask.repeat(batch_size, 1).to(labels.device)
adjusted_input_type_mask = input_type_mask
adjusted_input_type_mask[adjusted_input_type_mask > 0] += prefix_type_mask.max()
# Concatenate prefix and shifted input type masks
cpt_type_mask = torch.cat((prefix_type_mask, adjusted_input_type_mask), dim=1)
# Identify valid label positions and mask invalid ones with -100
labels_idx = (cpt_type_mask > 0) & (cpt_type_mask % 4 == 0)
cpt_labels[~labels_idx] = -100
# Update kwargs with the modified labels

kwargs["labels"] = cpt_labels
# Pass the modified inputs to the base model
base_model_output = self.base_model(inputs_embeds=inputs_embeds, **kwargs)
if labels is None:
return base_model_output
else:
# Calculate the loss using the custom CPT loss function
base_model_output = CPTEmbedding.calculate_loss(
base_model_output, cpt_labels, cpt_type_mask, self.peft_config["default"]
)
return base_model_output

def generate(self, *args, **kwargs):
peft_config = self.active_peft_config
self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
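
To make the label handling in _cpt_forward above easier to follow, here is a small standalone sketch with assumed shapes and values (not part of the diff). The reading of the type ids (values 1-4 per example, with multiples of 4 marking label tokens) is an interpretation of the masking rule, not something the diff states explicitly.

```python
import torch

batch_size = 1
# Hypothetical CPT prefix of 4 virtual tokens acting as one in-context example.
cpt_token_ids = [5, 6, 7, 8]
cpt_tokens_type_mask = [1, 2, 3, 4]          # 4 marks the example's label token
labels = torch.tensor([[11, 12, 13]])        # labels for the real input tokens
input_type_mask = torch.tensor([[2, 3, 4]])  # 4 again marks the query's label token

# Prefix labels come from the CPT token ids themselves.
prefix_labels = torch.tensor(cpt_token_ids).long().view(1, -1).repeat(batch_size, 1)
cpt_labels = torch.cat((prefix_labels, labels), dim=1)

# Shift the input type mask past the prefix values to avoid id collisions.
prefix_type_mask = torch.tensor(cpt_tokens_type_mask).long().view(1, -1).repeat(batch_size, 1)
adjusted_input_type_mask = input_type_mask.clone()
adjusted_input_type_mask[adjusted_input_type_mask > 0] += prefix_type_mask.max()
cpt_type_mask = torch.cat((prefix_type_mask, adjusted_input_type_mask), dim=1)

# Only positions whose type id is a positive multiple of 4 keep their labels.
labels_idx = (cpt_type_mask > 0) & (cpt_type_mask % 4 == 0)
cpt_labels[~labels_idx] = -100

print(cpt_type_mask)  # tensor([[1, 2, 3, 4, 6, 7, 8]])
print(cpt_labels)     # tensor([[-100, -100, -100, 8, -100, -100, 13]])
```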
1 change: 1 addition & 0 deletions src/peft/tuners/__init__.py
@@ -45,4 +45,5 @@
from .xlora import XLoraConfig, XLoraModel
from .hra import HRAConfig, HRAModel
from .vblora import VBLoRAConfig, VBLoRAModel
from .cpt import CPTConfig, CPTEmbedding
from .bone import BoneConfig, BoneModel
20 changes: 20 additions & 0 deletions src/peft/tuners/cpt/__init__.py
@@ -0,0 +1,20 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .config import CPTConfig
from .model import CPTEmbedding


__all__ = ["CPTConfig", "CPTEmbedding"]
134 changes: 134 additions & 0 deletions src/peft/tuners/cpt/config.py
@@ -0,0 +1,134 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
BenjaminBossan marked this conversation as resolved.
from dataclasses import dataclass, field
from typing import Optional

from peft.config import PromptLearningConfig
from peft.utils import PeftType


class CPTPromptInit(str, enum.Enum):
"""Enum for specifying the initialization method for CPT."""

TEXT = "TEXT" # Initialize using text-based embeddings.
RANDOM = "RANDOM" # Initialize randomly.


@dataclass
class CPTConfig(PromptLearningConfig):
"""
CPT Configuration class extending PeftConfig for Context-aware Prompt Tuning (CPT).

This class introduces additional parameters required for CPT, such as:
- Token type masks
- Prompt tuning initialization
- Loss weighting
- Projection settings

For more details, see the paper: https://arxiv.org/abs/2410.17222
"""

# Token-related configurations
cpt_token_ids: Optional[list[int]] = field(
default=None, metadata={"help": "Tensor of token IDs used for CPT prompts."}
)
cpt_mask: Optional[list[int]] = field(default=None, metadata={"help": "Tensor mask applied to CPT tokens."})
cpt_tokens_type_mask: Optional[list[int]] = field(
default=None, metadata={"help": "Mask indicating the type of each CPT token."}
)

# Prompt tuning initialization method
cpt_prompt_init: Optional[str] = field(
Member: Using Literal["TEXT", "RANDOM"] as type annotation would be a bit more precise.

tsachiblau (Contributor Author), Nov 13, 2024:
> I think this comment is still relevant.
It already exists in the code.
> Using Literal["TEXT", "RANDOM"] as type annotation would be a bit more precise.
I changed it.

default="TEXT", metadata={"help": "Initialization method: 'TEXT' for embedding-based, 'RANDOM' for random."}
)

# Loss-related configurations
opt_weighted_loss_type: Optional[str] = field(
Member: You could change the type to Literal["none", "decay"] to be more precise. Also, remove the Optional, as it implies that None is a valid option, which it is not.

Member: Still relevant

Member: My suggestion is to change the type annotation to Literal["none", "decay"].

default="none", metadata={"help": "Type of weighted loss: 'none' or 'decay'."}
)
opt_loss_decay_factor: Optional[float] = field(
default=1.0, metadata={"help": "Factor for exponential decay in loss weighting."}
)

# Projection-related configurations
opt_projection_epsilon: Optional[float] = field(
default=0.1, metadata={"help": "Epsilon value for input projection."}
)
opt_projection_format_epsilon: Optional[float] = field(
default=0.1, metadata={"help": "Epsilon value for format projection."}
)

# Tokenizer configuration
tokenizer_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "The tokenizer to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`"
},
)

# Virtual token configurations
num_virtual_tokens: int = field(default=0, metadata={"help": "Number of virtual tokens used in the prompt."})
Member: Is 0 a sensible default for num_virtual_tokens?

Member: I think having 0 as the default here makes little sense. WDYT about using a good default here, say, 10?


# CPT-specific static attributes
Member: These are not supposed to be modified by the user, right? In that case, let's move them inside of __post_init__.

Member: WDYT about this suggestion?

Member: Does it ever make sense to let users pass these arguments? If not, I would remove them here and place them inside the __post_init__ method.

is_prompt_learning = True # Indicates that CPT is a prompt-learning method.
num_layers = None # Number of layers (optional, not always required).
token_dim = None # Dimension of token embeddings.
num_attention_heads = None # Number of attention heads (if applicable).
task_type = "CAUSAL_LM" # Specifies that CPT is used for causal language modeling.
num_transformer_submodules = 1 # Number of transformer submodules used.

def __post_init__(self):
"""
Post-initialization hook to set additional attributes after the config is initialized.
"""
self.peft_type = PeftType.CPT # Specifies that the PEFT type is CPT.
self.task_type = "CAUSAL_LM" # Ensures task type is causal language modeling.
BenjaminBossan marked this conversation as resolved.

if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_token_ids is None:
self.cpt_token_ids = [0]
self.num_virtual_tokens = 1

if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_mask is None:
self.cpt_mask = [1 for _ in self.cpt_token_ids]

if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_tokens_type_mask is None:
self.cpt_tokens_type_mask = [1 for _ in self.cpt_token_ids]

if (self.cpt_prompt_init == CPTPromptInit.TEXT) and not (
len(self.cpt_token_ids) == len(self.cpt_mask) == len(self.cpt_tokens_type_mask) == self.num_virtual_tokens
):
raise ValueError(
f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', "
f"cpt_token_ids, cpt_mask and cpt_tokens_type_mask must have the same length."
)

if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.cpt_token_ids is not None:
raise ValueError(
f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " f"cpt_token_ids must be None."
)
if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.cpt_mask is not None:
raise ValueError(f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " f"cpt_mask must be None.")
if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.cpt_tokens_type_mask is not None:
raise ValueError(
f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " f"cpt_tokens_type_mask must be None."
)
if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.num_virtual_tokens == 0:
raise ValueError(
f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', "
f"num_virtual_tokens must be greater than zero."
)
if (self.cpt_prompt_init != CPTPromptInit.RANDOM) and (self.cpt_prompt_init != CPTPromptInit.TEXT):
raise ValueError("prompt_tuning_init must be 'RANDOM' or 'TEXT'")