Add unit test verifying compatibility with huggingface models (#1352)
Summary:
Pull Request resolved: #1352

Our current unit tests for LLM Attribution use mocked models that resemble Hugging Face transformer models (e.g. Llama, Llama2) but can differ in unexpected ways, such as [this](https://discuss.pytorch.org/t/trying-to-explain-zephyr-generative-llm/195262/3?fbclid=IwZXh0bgNhZW0CMTEAAR3REGbJsdhbNqG5LAyQ9_2J-82nPmNjt5avVyvNw-l8SMTWVXfI2DqIE8w_aem_GRP8EzELKtqDXDMZmox3Uw). To validate coverage and stay compatible with future changes to these models, this diff adds tests that run LLM Attribution against Hugging Face models directly, so breaking changes are caught quickly.

So far we only test the model type `LlamaForCausalLM`.

Differential Revision: D62894898
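
Since coverage currently stops at `LlamaForCausalLM`, a natural follow-up would be to parameterize the checkpoint name in the new test. The sketch below is illustrative only; the non-Llama checkpoint name is an assumed example and is not part of this diff.

# Hypothetical extension (not in this diff): loop over several tiny checkpoints
# instead of hard-coding the Llama one. The second model name is an assumption.
from transformers import AutoModelForCausalLM, AutoTokenizer

TINY_CHECKPOINTS = [
    "hf-internal-testing/tiny-random-LlamaForCausalLM",  # used by this diff
    "hf-internal-testing/tiny-random-MistralForCausalLM",  # assumed example
]

for checkpoint in TINY_CHECKPOINTS:
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    llm = AutoModelForCausalLM.from_pretrained(checkpoint)
    # ...then reuse the same LLMAttribution assertions as in the test below.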
DianjingLiu authored and facebook-github-bot committed Sep 18, 2024
1 parent 70619a6 commit 34a6279
Showing 2 changed files with 94 additions and 0 deletions.
scripts/install_via_conda.sh: 1 change (1 addition, 0 deletions)
@@ -37,6 +37,7 @@ fi
 # install other deps
 conda install -q -y pytest ipywidgets ipython scikit-learn parameterized werkzeug==2.2.2
 conda install -q -y -c conda-forge matplotlib pytest-cov flask flask-compress
+conda install -q -y transformers
 
 # install captum
 python setup.py develop
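
The extra conda dependency makes `transformers` importable so the new test can load Hugging Face models; in a pip-based environment the equivalent step would presumably be:

pip install transformers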
tests/attr/test_llm_attr_hf_compatibility.py: 93 changes (93 additions, 0 deletions)
@@ -0,0 +1,93 @@
#!/usr/bin/env python3

# pyre-strict

import unittest
from typing import cast, Dict, Optional, Type

import torch
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.interpretable_input import TextTemplateInput
from parameterized import parameterized, parameterized_class
from tests.helpers import BaseTest
from torch import Tensor

HAS_HF = True
try:
    # pyre-fixme[21]: Could not find a module corresponding to import `transformers`
    from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
    HAS_HF = False


@parameterized_class(
("device", "use_cached_outputs"),
(
[("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)]
if torch.cuda.is_available()
else [("cpu", True), ("cpu", False)]
),
)
class TestLLMAttr(BaseTest):
    # pyre-fixme[13]: Attribute `device` is never initialized.
    device: str
    # pyre-fixme[13]: Attribute `use_cached_outputs` is never initialized.
    use_cached_outputs: bool

    # pyre-fixme[56]: Pyre was not able to infer the type of argument `comprehension
    @parameterized.expand(
        [
            (
                AttrClass,
                delta,
                n_samples,
            )
            for AttrClass, delta, n_samples in zip(
                (FeatureAblation, ShapleyValueSampling, ShapleyValues),  # AttrClass
                (0.001, 0.001, 0.001),  # delta
                (None, 1000, None),  # n_samples
            )
        ]
    )
    def test_llm_attr(
        self,
        AttrClass: Type[PerturbationAttribution],
        delta: float,
        n_samples: Optional[int],
    ) -> None:
        if not HAS_HF:
            raise unittest.SkipTest("transformers package not found, skipping tests")
        attr_kws: Dict[str, int] = {}
        if n_samples is not None:
            attr_kws["n_samples"] = n_samples

        tokenizer = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )
        llm = AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )

        llm.to(self.device)
        llm.eval()
        llm_attr = LLMAttribution(AttrClass(llm), tokenizer)

        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
        res = llm_attr.attribute(
            inp,
            "m n o p q",
            use_cached_outputs=self.use_cached_outputs,
            # pyre-fixme[6]: In call `LLMAttribution.attribute`,
            # for 4th positional argument, expected
            # `Optional[typing.Callable[..., typing.Any]]` but got `int`.
            **attr_kws,  # type: ignore
        )
        self.assertEqual(res.seq_attr.shape, (4,))
        self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
        self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
        self.assertEqual(len(res.output_tokens), 5)
        self.assertEqual(res.seq_attr.device.type, self.device)
        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)
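
For reference, a minimal way to run just this test locally (assuming captum has been installed in development mode and both `pytest` and `transformers` are available) would be:

pytest tests/attr/test_llm_attr_hf_compatibility.py -v

Note that the test pulls the tiny `hf-internal-testing/tiny-random-LlamaForCausalLM` checkpoint from the Hugging Face Hub, so the first run needs network access or a pre-populated model cache.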
