[PEFT] Set eval mode when loading PEFT adapter (huggingface#34509)
* [PEFT] Set eval mode when loading PEFT adapter

Resolves huggingface#34469

When calling model.load_adapter to load a PEFT adapter, by default the
adapter should be set to eval mode. This is now correctly done. Users
can still pass is_trainable=True to load the adapter in training mode (see the usage sketch below).

* Linter
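
A minimal usage sketch of the behavior described above (the base model id and adapter path are placeholders, not taken from this commit; assumes peft is installed):

from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")  # placeholder model id

# Default: the adapter is loaded for inference only; the model ends up in eval mode
# and no parameter requires gradients.
base.load_adapter("path/to/adapter")  # placeholder adapter path
assert not base.training
assert not any(p.requires_grad for p in base.parameters())

# Opting in to training keeps the LoRA modules in training mode with trainable parameters.
trainable = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
trainable.load_adapter("path/to/adapter", is_trainable=True)
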
BenjaminBossan authored Nov 28, 2024
1 parent 5523e38 commit f4b674f
Showing 2 changed files with 51 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/transformers/integrations/peft.py
@@ -81,6 +81,7 @@ def load_adapter(
peft_config: Dict[str, Any] = None,
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
low_cpu_mem_usage: bool = False,
is_trainable: bool = False,
adapter_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
@@ -136,6 +137,9 @@ def load_adapter(
low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process.
Requires PEFT version 0.13.0 or higher.
is_trainable (`bool`, *optional*, defaults to `False`):
Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be
used for inference.
adapter_kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
`find_adapter_config_file` method.
@@ -209,6 +213,7 @@ def load_adapter(
token=token,
**adapter_kwargs,
)
peft_config.inference_mode = not is_trainable

# Create and add fresh new adapters into the model.
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
@@ -258,6 +263,9 @@ def load_adapter(
if err_msg:
logger.warning(err_msg)

if peft_config.inference_mode:
self.eval()

# Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
if (
(getattr(self, "hf_device_map", None) is not None)
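For a quick sanity check of this behavior outside the test suite, a small helper along these lines can be used (a sketch; adapter_dir stands for a hypothetical adapter path):

import torch

def summarize_adapter_state(model: torch.nn.Module) -> None:
    # Report the two properties this change controls: module training mode (toggled via
    # self.eval()) and parameter trainability (the adapter stays frozen unless loaded
    # with is_trainable=True).
    any_training = any(m.training for m in model.modules())
    any_trainable = any(p.requires_grad for p in model.parameters())
    print(f"training submodules: {any_training}, trainable parameters: {any_trainable}")

# After model.load_adapter(adapter_dir) both report False;
# after model.load_adapter(adapter_dir, is_trainable=True) the LoRA modules report True.
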
43 changes: 43 additions & 0 deletions tests/peft_integration/test_peft_integration.py
@@ -622,3 +622,46 @@ def test_peft_from_pretrained_missing_keys_warning(self):

msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}"
self.assertIn(msg, cl.out)

def test_peft_load_adapter_training_inference_mode_true(self):
"""
By default, when loading an adapter, the whole model should be in eval mode and no parameter should have
requires_grad=True.
"""
for model_id in self.peft_test_model_ids:
for transformers_class in self.transformers_test_model_classes:
peft_model = transformers_class.from_pretrained(model_id).to(torch_device)

with tempfile.TemporaryDirectory() as tmpdirname:
peft_model.save_pretrained(tmpdirname)
model = transformers_class.from_pretrained(peft_model.config._name_or_path)
model.load_adapter(tmpdirname)
assert not any(p.requires_grad for p in model.parameters())
assert not any(m.training for m in model.modules())
del model

def test_peft_load_adapter_training_inference_mode_false(self):
"""
When passing is_trainable=True, the LoRA modules should be in training mode and their parameters should have
requires_grad=True.
"""
for model_id in self.peft_test_model_ids:
for transformers_class in self.transformers_test_model_classes:
peft_model = transformers_class.from_pretrained(model_id).to(torch_device)

with tempfile.TemporaryDirectory() as tmpdirname:
peft_model.save_pretrained(tmpdirname)
model = transformers_class.from_pretrained(peft_model.config._name_or_path)
model.load_adapter(tmpdirname, is_trainable=True)

for name, module in model.named_modules():
if len(list(module.children())):
# only check leaf modules
continue

if "lora_" in name:
assert module.training
assert all(p.requires_grad for p in module.parameters())
else:
assert not module.training
assert all(not p.requires_grad for p in module.parameters())
