Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Issue with using rank_pattern and alpha_pattern together in LoraConfig #2195

Merged
merged 4 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/peft/tuners/lora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@

import math
import operator
import re
import warnings
from contextlib import contextmanager
from dataclasses import asdict, replace
from enum import Enum
from functools import partial, reduce
from itertools import chain
from typing import Literal, Optional

import torch
Expand All @@ -45,6 +43,7 @@
get_quantization_config,
)
from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties
from peft.utils.other import get_pattern_key

from .aqlm import dispatch_aqlm
from .awq import dispatch_awq
Expand Down Expand Up @@ -186,10 +185,10 @@ def _create_and_replace(
raise ValueError("Current Key shouldn't be `None`")

# Regexp matching - Find key which matches current target_name in patterns provided
pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys()))
target_name_key = next(filter(lambda key: re.match(rf".*\.{key}$", current_key), pattern_keys), current_key)
r = lora_config.rank_pattern.get(target_name_key, lora_config.r)
alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha)
r_key = get_pattern_key(lora_config.rank_pattern.keys(), current_key)
alpha_key = get_pattern_key(lora_config.alpha_pattern.keys(), current_key)
r = lora_config.rank_pattern.get(r_key, lora_config.r)
alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha)

kwargs = {
"r": r,
Expand Down
6 changes: 6 additions & 0 deletions src/peft/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import copy
import inspect
import os
import re
import warnings
from contextlib import nullcontext
from typing import Any, Optional
Expand Down Expand Up @@ -716,3 +717,8 @@ def check_file_exists_on_hf_hub(repo_id: str, filename: str, **kwargs) -> Option
)

return exists


def get_pattern_key(pattern_keys, key_to_match):
"""Match a substring of key_to_match in pattern keys"""
return next(filter(lambda key: re.match(rf".*\.{key}$", key_to_match), pattern_keys), key_to_match)
27 changes: 27 additions & 0 deletions tests/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,33 @@ def test_lora_scaling_default(self):
assert model.embed.scaling["default"] == expected_scaling
assert model.conv2d.scaling["default"] == expected_scaling

# testcase for bugfix for issue 2194
def test_pattern_override(self):
torch.manual_seed(0)

layer = self.get_model()
model = nn.Sequential(layer, layer)
config = LoraConfig(
target_modules=["linear"],
lora_alpha=1,
r=8,
use_rslora=False,
rank_pattern={"linear": 8},
alpha_pattern={"0.linear": 2},
)
model = get_peft_model(model, config)
scaling_with_rank_pattern = model.model[0].linear.scaling

layer = self.get_model()
model = nn.Sequential(layer, layer)
config = LoraConfig(
target_modules=["linear"], lora_alpha=1, r=8, use_rslora=False, alpha_pattern={"0.linear": 2}
)
model = get_peft_model(model, config)
scaling_without_rank_pattern = model.model[0].linear.scaling

assert scaling_with_rank_pattern == scaling_without_rank_pattern

def test_lora_pissa_linear_init_default(self, data):
model = self.get_model()
output = model(data)[0]
Expand Down
Loading