ENH: Faster adapter loading if there are a lot of target modules (#2045)
This is an optimization to reduce the number of entries in the
target_modules list. The reason is that, in some circumstances,
target_modules can contain hundreds of entries. Since each target module
is checked against each module of the network (which can have thousands
of modules), this can become quite expensive when many adapters are
added. Often, target_modules can be condensed in such a case, which
speeds up the process.

A context in which this can happen is when diffusers loads non-PEFT
LoRAs. As there is no meta info on target_modules in that case, they are
just inferred by listing all keys from the state_dict, which can be
quite a lot. See: huggingface/diffusers#9297

As shown there, the speed improvement for loading many diffusers LoRAs
can be substantial: when loading 30 adapters, the per-adapter loading
time used to grow from 0.6 sec to 3 sec; with this fix, it only grows
from 0.6 sec to 1 sec per adapter.

As there is a small chance for undiscovered bugs, we apply this
optimization only if the list of target_modules is sufficiently big.
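
To illustrate the idea, here is a minimal, self-contained sketch (not the code added in this commit; the actual helper is `_find_minimal_target_modules` below). A long list of fully qualified module names can often be replaced by a handful of suffixes that match exactly the same modules, which cuts the per-adapter matching cost from roughly `len(target_modules)` to `len(suffixes)` `endswith` checks per module name:

```py
# Illustrative sketch only -- not the implementation from this commit.
# Hundreds of fully qualified target names ...
target_modules = [
    f"model.decoder.layers.{i}.self_attn.{proj}" for i in range(100) for proj in ("q_proj", "v_proj")
]
# ... can often be condensed to a few suffixes that still select exactly the same modules:
condensed = {"q_proj", "v_proj"}

# Every original entry is still matched by one of the condensed suffixes,
# but matching now needs 2 endswith checks per module name instead of 200.
assert all(any(name.endswith(suffix) for suffix in condensed) for name in target_modules)
```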
BenjaminBossan committed Sep 2, 2024
1 parent 679bcd8 commit 01275b4
Showing 3 changed files with 247 additions and 2 deletions.
108 changes: 107 additions & 1 deletion src/peft/tuners/tuners_utils.py
@@ -31,7 +31,13 @@
from transformers.pytorch_utils import Conv1D

from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
from peft.utils.constants import DUMMY_MODEL_CONFIG, DUMMY_TARGET_MODULES, EMBEDDING_LAYER_NAMES, SEQ_CLS_HEAD_NAMES
from peft.utils.constants import (
DUMMY_MODEL_CONFIG,
DUMMY_TARGET_MODULES,
EMBEDDING_LAYER_NAMES,
MIN_TARGET_MODULES_FOR_OPTIMIZATION,
SEQ_CLS_HEAD_NAMES,
)
from peft.utils.peft_types import PeftType, TaskType

from ..config import PeftConfig
@@ -433,6 +439,26 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d
# update peft_config.target_modules if required
peft_config = _maybe_include_all_linear_layers(peft_config, model)

# This is an optimization to reduce the number of entries in the target_modules list. The reason is that in some
# circumstances, target_modules can contain hundreds of entries. Since each target module is checked against
# each module of the net (which can be thousands), this can become quite expensive when many adapters are being
# added. Often, the target_modules can be condensed in such a case, which speeds up the process.
# A context in which this can happen is when diffusers loads non-PEFT LoRAs. As there is no meta info on
# target_modules in that case, they are just inferred by listing all keys from the state_dict, which can be
# quite a lot. See: https://github.com/huggingface/diffusers/issues/9297
# As there is a small chance for undiscovered bugs, we apply this optimization only if the list of
# target_modules is sufficiently big.
if (
isinstance(peft_config.target_modules, (list, set))
and len(peft_config.target_modules) >= MIN_TARGET_MODULES_FOR_OPTIMIZATION
):
names_no_target = [
name for name in key_list if not any(name.endswith(suffix) for suffix in peft_config.target_modules)
]
new_target_modules = _find_minimal_target_modules(peft_config.target_modules, names_no_target)
if len(new_target_modules) < len(peft_config.target_modules):
peft_config.target_modules = new_target_modules

for key in key_list:
# Check for modules_to_save in case
if _check_for_modules_to_save and any(
@@ -781,6 +807,86 @@ def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optio
adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device)


def _find_minimal_target_modules(
target_modules: list[str] | set[str], other_module_names: list[str] | set[str]
) -> set[str]:
"""Find the minimal set of target modules that is sufficient to separate them from the other modules.
Sometimes, a very large list of target_modules could be passed, which can slow down loading of adapters (e.g. when
loaded from diffusers). It may be possible to condense this list from hundreds of items to just a handful of
suffixes that are sufficient to distinguish the target modules from the other modules.
Example:
```py
>>> from peft.tuners.tuners_utils import _find_minimal_target_modules
>>> target_modules = [f"model.decoder.layers.{i}.self_attn.q_proj" for i in range(100)]
>>> target_modules += [f"model.decoder.layers.{i}.self_attn.v_proj" for i in range(100)]
>>> other_module_names = [f"model.encoder.layers.{i}.self_attn.k_proj" for i in range(100)]
>>> _find_minimal_target_modules(target_modules, other_module_names)
{"q_proj", "v_proj"}
```
Args:
target_modules (`list[str]` | `set[str]`):
The list of target modules.
other_module_names (`list[str]` | `set[str]`):
The list of other module names. They must not overlap with the target modules.
Returns:
`set[str]`:
The minimal set of target modules that is sufficient to separate them from the other modules.
Raises:
ValueError:
If `target_modules` is not a list or set of strings or if it contains an empty string. Also raises an error
if `target_modules` and `other_module_names` contain common elements.
"""
if isinstance(target_modules, str) or not target_modules:
raise ValueError("target_modules should be a list or set of strings.")

target_modules = set(target_modules)
if "" in target_modules:
raise ValueError("target_modules should not contain an empty string.")

other_module_names = set(other_module_names)
if not target_modules.isdisjoint(other_module_names):
msg = (
"target_modules and other_module_names contain common elements, this should not happen, please "
"open a GitHub issue at https://github.com/huggingface/peft/issues with the code to reproduce this issue"
)
raise ValueError(msg)

# it is assumed that module name parts are separated by a "."
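# e.g. generate_suffixes("model.decoder.q_proj") == ["q_proj", "decoder.q_proj", "model.decoder.q_proj"] (shortest first)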
def generate_suffixes(s):
parts = s.split(".")
return [".".join(parts[i:]) for i in range(len(parts))][::-1]

# Create a reverse lookup for other_module_names to quickly check suffix matches
other_module_suffixes = {suffix for item in other_module_names for suffix in generate_suffixes(item)}

# Find all potential suffixes from target_modules
target_modules_suffix_map = {item: generate_suffixes(item) for item in target_modules}

# Initialize a set for required suffixes
required_suffixes = set()

for item, suffixes in target_modules_suffix_map.items():
# Go through target_modules items, shortest suffixes first
for suffix in suffixes:
# If the suffix is already in required_suffixes or matches other_module_names, skip it
if suffix in required_suffixes or suffix in other_module_suffixes:
continue
# Check if adding this suffix covers the item
if not any(item.endswith(req_suffix) for req_suffix in required_suffixes):
required_suffixes.add(suffix)
break

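# Fallback: if no distinguishing suffix could be found, keep the original target_modules as-is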
if not required_suffixes:
return set(target_modules)
return required_suffixes


def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
"""A helper method to check if the passed module's key name matches any of the target modules in the adapter_config.
6 changes: 6 additions & 0 deletions src/peft/utils/constants.py
@@ -262,3 +262,9 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
TOKENIZER_CONFIG_NAME = "tokenizer_config.json"
DUMMY_TARGET_MODULES = "dummy-target-modules"
DUMMY_MODEL_CONFIG = {"model_type": "custom"}

# If users specify more than this number of target modules, we apply an optimization to try to reduce the target modules
# to a minimal set of suffixes, which makes loading faster. We only apply this when exceeding a certain size since
# otherwise there is no point in optimizing and there is a small chance of bugs in the optimization algorithm, so no
# point in taking unnecessary risks. See #2045 for more context.
MIN_TARGET_MODULES_FOR_OPTIMIZATION = 20
135 changes: 134 additions & 1 deletion tests/test_tuners_utils.py
@@ -42,15 +42,19 @@
get_model_status,
get_peft_model,
)
from peft.tuners.lora.layer import LoraLayer
from peft.tuners.tuners_utils import (
BaseTuner,
BaseTunerLayer,
_maybe_include_all_linear_layers,
check_target_module_exists,
inspect_matched_modules,
)
from peft.tuners.tuners_utils import (
_find_minimal_target_modules as find_minimal_target_modules,
)
from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device
from peft.utils.constants import DUMMY_MODEL_CONFIG
from peft.utils.constants import DUMMY_MODEL_CONFIG, MIN_TARGET_MODULES_FOR_OPTIMIZATION

from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu

@@ -1149,3 +1153,132 @@ def test_no_warn_for_no_target_module_merge(self, recwarn):
model_no_target_module = self._get_peft_model(tie_word_embeddings=True, target_module="q_proj")
model_no_target_module.merge_and_unload()
assert not self._is_warn_triggered(recwarn.list, self.warn_end_merge)


class TestFindMinimalTargetModules:
@pytest.mark.parametrize(
"target_modules, other_module_names, expected",
[
(["bar"], [], {"bar"}),
(["foo"], ["bar"], {"foo"}),
(["1.foo", "2.foo"], ["3.foo", "4.foo"], {"1.foo", "2.foo"}),
# Could also return "bar.baz" but we want the shorter one
(["bar.baz"], ["foo.bar"], {"baz"}),
(["1.foo", "2.foo", "bar.baz"], ["3.foo", "bar.bla"], {"1.foo", "2.foo", "baz"}),
# Case with longer suffix chains and nested suffixes
(["a.b.c", "d.e.f", "g.h.i"], ["j.k.l", "m.n.o"], {"c", "f", "i"}),
(["a.b.c", "d.e.f", "g.h.i"], ["a.b.x", "d.x.f", "x.h.i"], {"c", "e.f", "g.h.i"}),
# Case with multiple items that can be covered by a single suffix
(["foo.bar.baz", "qux.bar.baz"], ["baz.bar.foo"], {"baz"}),
# Realistic examples
# Only match k_proj
(
["model.decoder.layers.{i}.self_attn.k_proj" for i in range(12)],
(
["model.decoder.layers.{i}.self_attn" for i in range(12)]
+ ["model.decoder.layers.{i}.self_attn.v_proj" for i in range(12)]
+ ["model.decoder.layers.{i}.self_attn.q_proj" for i in range(12)]
),
{"k_proj"},
),
# Match all k_proj except the one in layer 5 => no common suffix
(
["model.decoder.layers.{i}.self_attn.k_proj" for i in range(12) if i != 5],
(
["model.decoder.layers.5.self_attn.k_proj"]
+ ["model.decoder.layers.{i}.self_attn" for i in range(12)]
+ ["model.decoder.layers.{i}.self_attn.v_proj" for i in range(12)]
+ ["model.decoder.layers.{i}.self_attn.q_proj" for i in range(12)]
),
{"{i}.self_attn.k_proj" for i in range(12) if i != 5},
),
],
)
def test_find_minimal_target_modules(self, target_modules, other_module_names, expected):
# check all possible combinations of list and set
result = find_minimal_target_modules(target_modules, other_module_names)
assert result == expected

result = find_minimal_target_modules(set(target_modules), other_module_names)
assert result == expected

result = find_minimal_target_modules(target_modules, set(other_module_names))
assert result == expected

result = find_minimal_target_modules(set(target_modules), set(other_module_names))
assert result == expected

def test_find_minimal_target_modules_empty_raises(self):
with pytest.raises(ValueError, match="target_modules should be a list or set of strings"):
find_minimal_target_modules([], ["foo"])

with pytest.raises(ValueError, match="target_modules should be a list or set of strings"):
find_minimal_target_modules(set(), ["foo"])

def test_find_minimal_target_modules_contains_empty_string_raises(self):
target_modules = ["", "foo", "bar.baz"]
other_module_names = ["bar"]
with pytest.raises(ValueError, match="target_modules should not contain an empty string"):
find_minimal_target_modules(target_modules, other_module_names)

def test_find_minimal_target_modules_string_raises(self):
target_modules = "foo"
other_module_names = ["bar"]
with pytest.raises(ValueError, match="target_modules should be a list or set of strings"):
find_minimal_target_modules(target_modules, other_module_names)

@pytest.mark.parametrize(
"target_modules, other_module_names",
[
(["foo"], ["foo"]),
(["foo.bar"], ["foo.bar"]),
(["foo.bar", "spam", "eggs"], ["foo.bar"]),
(["foo.bar", "spam"], ["foo.bar", "eggs"]),
(["foo.bar"], ["foo.bar", "spam", "eggs"]),
],
)
def test_find_minimal_target_modules_not_disjoint_raises(self, target_modules, other_module_names):
msg = (
"target_modules and other_module_names contain common elements, this should not happen, please "
"open a GitHub issue at https://github.com/huggingface/peft/issues with the code to reproduce this issue"
)
with pytest.raises(ValueError, match=msg):
find_minimal_target_modules(target_modules, other_module_names)

def test_get_peft_model_applies_find_target_modules(self):
# Check that when calling get_peft_model, the target_module optimization is indeed applied if the length of
# target_modules is big enough. The resulting model itself should be unaffected.
torch.manual_seed(0)
model_id = "facebook/opt-125m" # must be big enough for optimization to trigger
model = AutoModelForCausalLM.from_pretrained(model_id)

# base case: specify target_modules in a minimal fashion
config = LoraConfig(init_lora_weights=False, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, config)

# this list contains all targeted modules listed separately
big_target_modules = [name for name, module in model.named_modules() if isinstance(module, LoraLayer)]
# sanity check
assert len(big_target_modules) > MIN_TARGET_MODULES_FOR_OPTIMIZATION

# make a "checksum" of the model for comparison
model_check_sum_before = sum(p.sum() for p in model.parameters())

# strip the prefix so that the names can be used as new target_modules
prefix_to_strip = "base_model.model.model."
big_target_modules = [name[len(prefix_to_strip) :] for name in big_target_modules]

del model

torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(model_id)
# pass the big target_modules to config
config = LoraConfig(init_lora_weights=False, target_modules=big_target_modules)
model = get_peft_model(model, config)

# check that target modules have been condensed
assert model.peft_config["default"].target_modules == {"q_proj", "v_proj"}

# check that the resulting model is still the same
model_check_after = sum(p.sum() for p in model.parameters())
assert model_check_sum_before == model_check_after
