Skip to content

Commit

Permalink
ENH Raise error when applying modules_to_save on tuner layer (#2028)
Browse files Browse the repository at this point in the history
Relates to #2027

Normally, when selecting the layers for fine-tuning, PEFT already
ensures that the same layer is not targeted for both parameter-efficient
fine-tuning (e.g. LoRA layer) and full fine-tuning (via
modules_to_save), as that makes no sense.

However, there is a loophole when the modules_to_save is applied ex
post. This happens for instance when having a task type like sequence
classification, where PEFT will automatically add the classfication head
to modules_to_save for the user. This loophole is now closed by adding a
check to ModulesToSaveWrapper that validates that the targeted layer is
not a tuner layer.

This does not fully resolve #2027 but will raise an early error in the
future to avoid confusion.

On top of this, the error message inside of
ModulesToSaveWrapper.check_module has been slightly adjusted.
Previously, the class name would be used, which can be confusing. E.g.
for LoRA, the class name of the linear LoRA layer is just "Linear",
which looks the same as nn.Linear. Therefore, the full name is now
shown.
  • Loading branch information
BenjaminBossan committed Aug 22, 2024
1 parent 8fcb195 commit f3c7c6e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 26 deletions.
10 changes: 9 additions & 1 deletion src/peft/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,15 @@ def check_module(self):
# ModuleList, even though their forward methods cannot be called
forbidden_classes = (torch.nn.ModuleDict, torch.nn.ModuleList, torch.nn.ParameterDict, torch.nn.ParameterList)
if isinstance(self.original_module, forbidden_classes):
cls_name = self.original_module.__class__.__name__
cls_name = self.original_module.__class__
raise TypeError(f"modules_to_save cannot be applied to modules of type {cls_name}")

# local import to avoid circular import
from peft.tuners.tuners_utils import BaseTunerLayer

if isinstance(self.original_module, BaseTunerLayer):
# e.g. applying modules_to_save to a lora layer makes no sense
cls_name = self.original_module.__class__
raise TypeError(f"modules_to_save cannot be applied to modules of type {cls_name}")

@property
Expand Down
25 changes: 1 addition & 24 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
get_peft_model,
)
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import ModulesToSaveWrapper, infer_device
from peft.utils import infer_device

from .testing_common import PeftCommonTester
from .testing_utils import get_state_dict, require_non_cpu
Expand Down Expand Up @@ -1530,29 +1530,6 @@ def test_adapter_name_makes_no_difference(self, config0):
assert torch.allclose(output_custom1, output_custom2)
assert torch.allclose(output_default, output_custom1)

@parameterized.expand(["merge_and_unload", "unload"])
def test_double_wrapping_merge_and_unload(self, method):
# see issue #1485
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("hf-internal-testing/tiny-random-RobertaModel")
config = LoraConfig(task_type="TOKEN_CLS", target_modules="all-linear")
model = get_peft_model(model, config)

# first check that double-wrapping happened
# Note: this may get fixed in a future PR, in which case this test can be removed
assert isinstance(model.base_model.model.classifier, ModulesToSaveWrapper)
assert hasattr(model.base_model.model.classifier.original_module, "lora_A")
assert hasattr(model.base_model.model.classifier.modules_to_save.default, "lora_A")

# after unloading, despite double wrapping, the classifier module should be a normal nn.Linear layer
if method == "merge_and_unload":
unloaded = model.merge_and_unload()
else:
unloaded = model.unload()

assert isinstance(unloaded.classifier, nn.Linear)

def test_gpt2_dora_merge_and_unload(self):
# see https://github.com/huggingface/peft/pull/1588#discussion_r1537914207
model = AutoModelForCausalLM.from_pretrained("gpt2")
Expand Down
19 changes: 18 additions & 1 deletion tests/test_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest
import torch
from torch import nn
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification

from peft import LoraConfig, get_peft_model

Expand Down Expand Up @@ -76,6 +76,23 @@ def test_modules_to_save_targets_module_dict_raises(cls):
get_peft_model(model=model, peft_config=peft_config)


def test_modules_to_save_targets_tuner_layer_raises():
# See e.g. issue 2027
# Prevent users from (accidentally) targeting the same layer both with a tuner and modules_to_save. Normally, PEFT
# will not target the same layer with both a tuner and ModulesToSaveWrapper. However, if modules_to_save is
# automatically inferred, e.g. when using AutoModelForSequenceClassification, the ModulesToSaveWrapper is applied ex
# post, which can lead to the double wrapping.
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
model = AutoModelForSequenceClassification.from_pretrained(model_id)

# Note: target_modules="all-linear" would also work and is closer to the original issue, but let's explicitly target
# "score" here in case that "all-linear" will be fixed to no longer target the score layer.
peft_config = LoraConfig(target_modules=["score"], task_type="SEQ_CLS")
msg = "modules_to_save cannot be applied to modules of type"
with pytest.raises(TypeError, match=msg):
get_peft_model(model, peft_config)


def test_get_peft_model_revision_warning(tmp_path):
base_model_id = "peft-internal-testing/tiny-random-BertModel"
base_revision = "v2.0.0"
Expand Down

0 comments on commit f3c7c6e

Please sign in to comment.