Skip to content

Commit

Permalink
Add unit test verifying compatibility with huggingface models (pytorc…
Browse files Browse the repository at this point in the history
…h#1352)

Summary:
Pull Request resolved: pytorch#1352

Our current unit tests for LLM Attribution use mocked models which are similar to huggingface transformer models (e.g. Llama, Llama2), but may have some unexpected differences such as [this](https://discuss.pytorch.org/t/trying-to-explain-zephyr-generative-llm/195262/3?fbclid=IwZXh0bgNhZW0CMTEAAR3REGbJsdhbNqG5LAyQ9_2J-82nPmNjt5avVyvNw-l8SMTWVXfI2DqIE8w_aem_GRP8EzELKtqDXDMZmox3Uw). To validate coverage and ensure compatibility with future changes to models, we would like to add tests using huggingface models directly and validate compatibility with LLM Attribution, which will help us quickly catch any breaking changes.

So far we only test for model type `LlamaForCausalLM`

Differential Revision: D62894898
  • Loading branch information
DianjingLiu authored and facebook-github-bot committed Sep 18, 2024
1 parent 70619a6 commit 53ca7ff
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 0 deletions.
1 change: 1 addition & 0 deletions scripts/install_via_conda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ fi
# install other deps
conda install -q -y pytest ipywidgets ipython scikit-learn parameterized werkzeug==2.2.2
conda install -q -y -c conda-forge matplotlib pytest-cov flask flask-compress
conda install -q -y transformers

# install captum
python setup.py develop
2 changes: 2 additions & 0 deletions scripts/install_via_pip.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,5 @@ fi
if [[ $DEPLOY == true ]]; then
pip install beautifulsoup4 ipython nbconvert==5.6.1 --progress-bar off
fi

pip install transformers --progress-bar off
94 changes: 94 additions & 0 deletions tests/attr/test_llm_attr_hf_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env python3


from typing import cast, Dict, Optional, Type

import torch
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.interpretable_input import TextTemplateInput
from parameterized import parameterized, parameterized_class
from tests.helpers import BaseTest
from torch import Tensor

HAS_HF = True
try:
# pyre-fixme[21]: Could not find a module corresponding to import `transformers`
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
HAS_HF = False


@parameterized_class(
("device", "use_cached_outputs"),
(
[("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)]
if torch.cuda.is_available()
else [("cpu", True), ("cpu", False)]
),
)
class TestLLMAttrHFCompatibility(BaseTest):
# pyre-fixme[13]: Attribute `device` is never initialized.
device: str
# pyre-fixme[13]: Attribute `use_cached_outputs` is never initialized.
use_cached_outputs: bool

def setUp(self) -> None:
if not HAS_HF:
self.skipTest("transformers package not found, skipping tests")
super().setUp()

# pyre-fixme[56]: Pyre was not able to infer the type of argument `comprehension
@parameterized.expand(
[
(
AttrClass,
delta,
n_samples,
)
for AttrClass, delta, n_samples in zip(
(FeatureAblation, ShapleyValueSampling, ShapleyValues), # AttrClass
(0.001, 0.001, 0.001), # delta
(None, 1000, None), # n_samples
)
]
)
def test_llm_attr_hf_compatibility(
self,
AttrClass: Type[PerturbationAttribution],
delta: float,
n_samples: Optional[int],
) -> None:
attr_kws: Dict[str, int] = {}
if n_samples is not None:
attr_kws["n_samples"] = n_samples

tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/tiny-random-LlamaForCausalLM"
)
llm = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-LlamaForCausalLM"
)

llm.to(self.device)
llm.eval()
llm_attr = LLMAttribution(AttrClass(llm), tokenizer)

inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
res = llm_attr.attribute(
inp,
"m n o p q",
use_cached_outputs=self.use_cached_outputs,
# pyre-fixme[6]: In call `LLMAttribution.attribute`,
# for 4th positional argument, expected
# `Optional[typing.Callable[..., typing.Any]]` but got `int`.
**attr_kws, # type: ignore
)
self.assertEqual(res.seq_attr.shape, (4,))
self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
self.assertEqual(len(res.output_tokens), 5)
self.assertEqual(res.seq_attr.device.type, self.device)
self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)

0 comments on commit 53ca7ff

Please sign in to comment.