[WIP] Allow for attention caching during CoCa generation #502

Draft · wants to merge 11 commits into base: main
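This draft threads an attention cache through CoCa generation so that, at each decoding step, the residual attention blocks recompute attention only for the newest token and reuse the outputs cached from earlier steps. A minimal usage sketch under stated assumptions: the model name, pretrained tag, `generation_type` value, and the `open_clip.decode` helper are existing open_clip conventions used here for illustration; `caching=True` is the flag added by this PR.

```python
import torch
import open_clip

# Assumes a CoCa checkpoint is available through the open_clip factory; the
# model name and pretrained tag below are illustrative, not part of this PR.
model, _, preprocess = open_clip.create_model_and_transforms(
    "coca_ViT-B-32", pretrained="laion2b_s13b_b90k"
)
model.eval()

images = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image batch

with torch.no_grad():
    # caching=True is the new kwarg: _generate / _generate_beamsearch carry a
    # cache dict across steps instead of recomputing attention for the full prefix.
    tokens = model.generate(images, generation_type="top_k", caching=True)

print(open_clip.decode(tokens[0]))
```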
54 changes: 44 additions & 10 deletions src/open_clip/coca_model.py
@@ -133,11 +133,15 @@ def _encode_image(self, images, normalize=True):
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
return image_latent, tokens_embs

def _encode_text(self, text, normalize=True, embed_cls=True):
def _encode_text(self, text, normalize=True, embed_cls=True, cache=None):
text = text[:, :-1] if embed_cls else text # make space for CLS token
text_latent, token_emb = self.text(text)
if cache is not None:
text_latent, token_emb, attentions = self.text(text, cache)
else:
text_latent, token_emb = self.text(text)
attentions = None
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return text_latent, token_emb
return text_latent, token_emb, attentions

def encode_image(self, images, normalize=True):
image_latent, _ = self._encode_image(images, normalize=normalize)
@@ -147,19 +151,22 @@ def encode_text(self, text, normalize=True, embed_cls=True):
text_latent, _, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
return text_latent

def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None, cache=None):
# TODO: Fix encoder caching
text_latent, token_embs, text_attentions = self._encode_text(text, embed_cls=embed_cls, cache=None)
if image_latent is None or image_embs is None:
image_latent, image_embs = self._encode_image(image)

# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]

logits = self.text_decoder(image_embs, token_embs)
dec_cache = cache["dec"] if cache is not None else {"self": None, "cross": None}
logits, attentions, cross_attentions = self.text_decoder(image_embs, token_embs, cache=dec_cache)
return {
"image_features": image_latent,
"text_features": text_latent,
"logits": logits,
"text_attentions": text_attentions,
"attentions": attentions,
"cross_attentions": cross_attentions,
"labels": labels,
"logit_scale": self.logit_scale.exp()
}
@@ -182,7 +189,8 @@ def generate(
min_seq_len=5,
stopping_criteria=None,
repetition_penalty=1.0,
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
fixed_output_length=False, # if True output.shape == (batch_size, seq_len)
caching=False, # cache previously computed attentions
):
# taking many ideas and components from HuggingFace GenerationMixin
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
@@ -220,6 +228,7 @@ def generate(
min_seq_len=min_seq_len,
stopping_criteria=stopping_criteria,
logit_processor=logit_processor,
caching=caching
)
if fixed_output_length and output.shape[1] < seq_len:
return torch.cat(
@@ -252,11 +261,24 @@ def generate(
cur_len = text.shape[1]
self.eval()
out = text

if caching:
cache = {"enc": [], "dec": {"self": [], "cross": []}}
else:
cache = {"enc": None, "dec": {"self": None, "cross": None}}

while True:
x = out[:, -max_seq_len:]
cur_len = x.shape[1]
logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]

outputs = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False, cache=cache)

if caching:
cache["enc"] = outputs["text_attentions"]
cache["dec"]["self"] = outputs["attentions"]
cache["dec"]["cross"] = outputs["cross_attentions"]

logits = outputs["logits"][:, -1]
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

@@ -299,6 +321,7 @@ def _generate_beamsearch(
stopping_criteria=None,
logit_processor=None,
logit_warper=None,
caching=True,
):
device = image_inputs.device
batch_size = image_inputs.shape[0]
@@ -338,6 +361,11 @@ def _generate_beamsearch(
beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,))

if caching:
cache = {"enc": [], "dec": {"self": [], "cross": []}}
else:
cache = {"enc": None, "dec": {"self": None, "cross": None}}

while True:

# predicted tokens in cur_len step
@@ -353,9 +381,15 @@
model_inputs['text'],
embed_cls=False,
image_latent=image_latent,
image_embs=image_embs
image_embs=image_embs,
cache=cache
)

if caching:
cache["enc"] = outputs["text_attentions"]
cache["dec"]["self"] = outputs["attentions"]
cache["dec"]["cross"] = outputs["cross_attentions"]

for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
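For reference, a sketch of the cache layout that `_generate` / `_generate_beamsearch` build above and that `MultimodalTransformer.forward` indexes per layer. The layer count and tensor sizes are made up for illustration; each entry holds an attention-block output for the already-generated prefix in the transformer's internal LND (seq, batch, width) layout.

```python
import torch

num_layers = 12                       # assumption: whatever the text decoder depth is
cached_len, batch, width = 7, 4, 512  # made-up sizes, purely for illustration

cache = {
    # filled from outputs["text_attentions"], but not consumed yet
    # (encoder caching is still marked TODO in CoCa.forward)
    "enc": [torch.zeros(cached_len, batch, width) for _ in range(num_layers)],
    "dec": {
        # one cached self-attention output per residual block
        "self": [torch.zeros(cached_len, batch, width) for _ in range(num_layers)],
        # one cached cross-attention output per cross-attention block
        "cross": [torch.zeros(cached_len, batch, width) for _ in range(num_layers)],
    },
}
```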
85 changes: 68 additions & 17 deletions src/open_clip/transformer.py
@@ -1,6 +1,6 @@
from collections import OrderedDict
import math
from typing import Callable, Optional, Sequence, Tuple
from typing import Callable, Optional, Sequence, Tuple, List, Union

import torch
from torch import nn
@@ -220,28 +220,47 @@ def attention(
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
cache: Optional[torch.Tensor] = None,
):

k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x

if cache is not None:
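# only the newest query position needs fresh attention; slice the query and the matching mask rows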
q_x = q_x[-1:]
if attn_mask is not None:
if attn_mask.dim() == 2:
attn_mask = attn_mask[-1:]
elif attn_mask.dim() == 3:
attn_mask = attn_mask[:,-1:]

attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
return self.attn(

out = self.attn(
q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
)[0]

if cache is not None:
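# prepend the attention outputs cached from earlier decoding steps so the caller sees the full sequence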
out = torch.cat((cache, out), dim=0)

return out

def forward(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
cache: Optional[torch.Tensor] = None,
):
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None

x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))

attn = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask, cache=cache)
x = q_x + self.ls_1(attn)
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
return x, attn


class CustomResidualAttentionBlock(nn.Module):
@@ -310,13 +329,21 @@ def __init__(
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype

def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
for r in self.resblocks:
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, cache: Optional[List[torch.Tensor]] = None):
attentions = []
for i, r in enumerate(self.resblocks):
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
x = checkpoint(r, x, None, None, attn_mask)
x, _ = checkpoint(r, x, None, None, attn_mask, None)
elif cache is None or len(cache) != len(self.resblocks):
x, _ = r(x, attn_mask=attn_mask, cache=None)
else:
x = r(x, attn_mask=attn_mask)
x, attn = r(x, attn_mask=attn_mask, cache=cache[i])
attentions.append(attn)

if cache is not None:
return x, attentions

return x


@@ -497,7 +524,7 @@ def forward(self, x: torch.Tensor):

if self.output_tokens:
return pooled, tokens

return pooled


@@ -594,7 +621,7 @@ def build_cls_mask(self, text, cast_dtype: torch.dtype):
def _repeat(self, t, N: int):
return t.reshape(1, 1, -1).repeat(N, 1, 1)

def forward(self, text):
def forward(self, text, cache=None):
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]

@@ -608,7 +635,10 @@

x = x + self.positional_embedding[:seq_len].to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=attn_mask)
if cache is not None:
x, attentions = self.transformer(x, attn_mask=attn_mask, cache=cache)
else:
x = self.transformer(x, attn_mask=attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD

# x.shape = [batch_size, n_ctx, transformer.width]
@@ -623,8 +653,12 @@
if self.text_projection is not None:
pooled = pooled @ self.text_projection

if self.output_tokens:
if self.output_tokens and cache is None:
return pooled, tokens
elif self.output_tokens and cache is not None:
return pooled, tokens, attentions
elif cache is not None:
return pooled, attentions

return pooled

@@ -697,25 +731,42 @@ def build_attention_mask(self):
mask.triu_(1) # zero out the lower diagonal
return mask

def forward(self, image_embs, text_embs):
def forward(self, image_embs, text_embs, cache=None):
text_embs = text_embs.permute(1, 0, 2) # NLD -> LND
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
seq_len = text_embs.shape[0]
attentions = []
cross_attentions = []

valid_cache = cache is not None and cache["self"] is not None and cache["self"] != []

for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
for i, (resblock, cross_attn) in enumerate(zip(self.resblocks, self.cross_attn)):
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
else:
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
# Could allow for only caching one or the other, but unsure when that
# would be beneficial over the alternatives
if not valid_cache:
text_embs, attn = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len], cache=None)
text_embs, x_attn = cross_attn(text_embs, k_x=image_embs, v_x=image_embs, cache=None)
else:
text_embs, attn = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len], cache=cache["self"][i])
text_embs, x_attn = cross_attn(text_embs, k_x=image_embs, v_x=image_embs, cache=cache["cross"][i])
attentions.append(attn)
cross_attentions.append(x_attn)

x = text_embs.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)

if self.text_projection is not None:
x = x @ self.text_projection

if cache is not None:
return x, attentions, cross_attentions

return x

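The core trick in `ResidualAttentionBlock.attention()` is to run attention only for the newest query position and concatenate that single row onto the outputs cached from earlier steps; under a causal mask the earlier rows of the attention output do not change when a token is appended, so this matches a full recompute. A standalone sketch of that idea using `nn.MultiheadAttention` directly (an illustration of the mechanism, not the PR's exact code):

```python
import torch
from torch import nn

attn = nn.MultiheadAttention(embed_dim=64, num_heads=4)  # LND layout (batch_first=False)
L, N, D = 5, 2, 64
x = torch.randn(L, N, D)
causal = torch.full((L, L), float("-inf")).triu_(1)       # additive causal mask

# step 1: no cache yet -> attend over the full prefix and keep the outputs
cache = attn(x, x, x, need_weights=False, attn_mask=causal)[0]

# step 2: a new token arrives; compute attention for the last position only
x_new = torch.cat([x, torch.randn(1, N, D)], dim=0)
causal_new = torch.full((L + 1, L + 1), float("-inf")).triu_(1)
out_last = attn(x_new[-1:], x_new, x_new, need_weights=False,
                attn_mask=causal_new[-1:])[0]              # query = newest position only
out = torch.cat([cache, out_last], dim=0)                  # same as recomputing all L + 1 rows
```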