Fix AttentionPooler batch_first change, remove device arg from logit processor as it's very very new, move sot/eos to tensor beforehand
rwightman committed Jul 3, 2024
1 parent 6b555de commit 6a5068e
Showing 2 changed files with 14 additions and 7 deletions.
16 changes: 12 additions & 4 deletions src/open_clip/coca_model.py
@@ -78,6 +78,14 @@ def _build_text_decoder_tower(
     return decoder
 
 
+def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
+    if not isinstance(token_id, torch.Tensor):
+        if isinstance(token_id, int):
+            token_id = [token_id]
+        token_id = torch.tensor(token_id, device=device)
+    return token_id
+
+
 class CoCa(nn.Module):
     def __init__(
             self,
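Note: the new _token_to_tensor helper simply normalizes whatever the caller passes for a token id. A self-contained sketch of its behavior (the helper is repeated here only so the snippet runs on its own; the example ids are the CLIP BPE sot/eos defaults used below):

import torch

def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
    # Wrap a bare int in a list, then turn any non-tensor input into a tensor
    # on the requested device; an existing tensor is returned as-is (not moved).
    if not isinstance(token_id, torch.Tensor):
        if isinstance(token_id, int):
            token_id = [token_id]
        token_id = torch.tensor(token_id, device=device)
    return token_id

print(_token_to_tensor(49407))                    # tensor([49407])
print(_token_to_tensor([49406, 49407]))           # tensor([49406, 49407])
print(_token_to_tensor(torch.tensor([49407])))    # passed through unchanged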
@@ -218,12 +226,12 @@ def generate(
         device = image.device
 
         with torch.no_grad():
-            sot_token_id = 49406 if sot_token_id is None else sot_token_id
-            eos_token_id = 49407 if eos_token_id is None else eos_token_id
+            sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
+            eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
             pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
             logit_processor = LogitsProcessorList(
                 [
-                    MinLengthLogitsProcessor(min_seq_len, eos_token_id, device=device),
+                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                     RepetitionPenaltyLogitsProcessor(repetition_penalty),
                 ]
             )
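Note: the device= kwarg of MinLengthLogitsProcessor only exists in very recent transformers releases, so generate() now drops it and instead hands the processor an eos id that is already a tensor on the right device. A rough standalone sketch of the same construction (assumes the transformers package is installed; the min length, penalty, and ids below are illustrative):

import torch
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
)

eos_token_id = torch.tensor([49407])   # eos pre-converted to a tensor, as in the diff
logit_processor = LogitsProcessorList([
    MinLengthLogitsProcessor(5, eos_token_id),        # no device= kwarg needed
    RepetitionPenaltyLogitsProcessor(1.1),
])

# Processors map (batch, vocab) logits to adjusted logits given the ids generated so far.
input_ids = torch.tensor([[49406]])     # only the sot token so far -> below min length
scores = torch.zeros(1, 49408)          # dummy vocab-sized logits
scores = logit_processor(input_ids, scores)
print(scores[0, 49407])                 # eos logit forced to -inf until min length is reached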
@@ -248,7 +256,7 @@
                 pad_len = seq_len - output.shape[1]
                 return torch.cat((
                         output,
-                        torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * self.pad_id
+                        torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                     ),
                     dim=1
                 )
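Note: this last hunk just makes the fixed-length padding respect a caller-supplied pad id (pad_token_id) instead of always using self.pad_id. A tiny sketch of that padding step with made-up shapes and ids:

import torch

seq_len = 10                                           # hypothetical target length
pad_token_id = 0                                       # hypothetical pad id
output = torch.tensor([[49406, 320, 1125, 49407]])     # (batch=1, generated_len=4)

# Right-pad every sequence to seq_len with pad_token_id, as in the diff above.
pad_len = seq_len - output.shape[1]
padded = torch.cat((
    output,
    torch.ones(output.shape[0], pad_len, device=output.device, dtype=output.dtype) * pad_token_id,
), dim=1)
print(padded)          # [[49406, 320, 1125, 49407, 0, 0, 0, 0, 0, 0]]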
5 changes: 2 additions & 3 deletions src/open_clip/transformer.py
@@ -200,10 +200,10 @@ def __init__(
         self.ln_k = norm_layer(context_dim)
 
     def forward(self, x: torch.Tensor):
+        N = x.shape[0]
         x = self.ln_k(x)
-        N = x.shape[1]
         q = self.ln_q(self.query)
-        out = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0]
+        out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0]
         return out
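Note: the AttentionalPooler fix above aligns the forward pass with a batch-first nn.MultiheadAttention: inputs are (batch, seq, dim), so the batch size is x.shape[0] and the learned (n_queries, d_model) query has to be expanded along a new leading batch dim. A simplified sketch with illustrative sizes (the real pooler also layer-norms q and x and sets kdim/vdim; omitted here):

import torch
import torch.nn as nn

# Illustrative sizes only; the real pooler takes its dims and query count from the model config.
N, L, d_model, n_queries = 2, 7, 16, 4

attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
query = torch.randn(n_queries, d_model)     # learned query parameter, no batch dim
x = torch.randn(N, L, d_model)              # batch-first context tokens: (batch, seq, dim)

# The fix: batch size is x.shape[0], and the query is expanded along a new batch dim.
q = query.unsqueeze(0).expand(N, -1, -1)    # (N, n_queries, d_model)
out, _ = attn(q, x, x, need_weights=False)
print(out.shape)                            # torch.Size([2, 4, 16]) == (N, n_queries, d_model)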


@@ -823,7 +823,6 @@ def __init__(
             output_dim: int = 512,
             batch_first: bool = True,
     ):
-
         super().__init__(
             width=width,
             layers=layers,
