Skip to content

Commit

Permalink
Condense long text in plot axis (pytorch#1349)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1349

Captum will display complete plot axis labels, which makes the plot unreadable if you have longer segments than words; cap the max length at 50 characters.

Reviewed By: cyrjano

Differential Revision: D62758379

fbshipit-source-id: 96260b6e0101033b7ff580568902c7cdd54a8a46
  • Loading branch information
csauper authored and facebook-github-bot committed Sep 17, 2024
1 parent 6636f4d commit 70619a6
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions captum/attr/_core/llm_attr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# pyre-strict
from copy import copy

from textwrap import shorten

from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -103,7 +105,10 @@ def plot_token_attr(
cbar.ax.set_ylabel("Token Attribuiton", rotation=-90, va="bottom")

# Show all ticks and label them with the respective list entries.
ax.set_xticks(np.arange(data.shape[1]), labels=self.input_tokens)
shortened_tokens = [
shorten(t, width=50, placeholder="...") for t in self.input_tokens
]
ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens)

# Let the horizontal axes labeling appear on top.
Expand Down Expand Up @@ -149,7 +154,10 @@ def plot_seq_attr(

data = self.seq_attr.cpu().numpy()

ax.set_xticks(range(data.shape[0]), labels=self.input_tokens)
shortened_tokens = [
shorten(t, width=50, placeholder="...") for t in self.input_tokens
]
ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)

ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

Expand Down

0 comments on commit 70619a6

Please sign in to comment.