Skip to content

Commit

Permalink
Code compile on H100
Browse files: browse the repository at this point in the history
  • Loading branch information
jainapurva committed Sep 15, 2024
1 parent 85d03de commit c7e5fdf
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion scripts/hf_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def all_linear(mod, name):
with torch.no_grad():
result = evaluate(
HFLM(
pretrained=model.to(device),
pretrained=model,
tokenizer=tokenizer,
batch_size=batch_size,
max_length=max_length),
Expand Down
5 changes: 4 additions & 1 deletion torchao/kernel/intmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
input = (
input.contiguous()
) # (it seems the transpose makes cublas check the above j constraint on i)
return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
try:
return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
except:
return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
else:
def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
"""
Expand Down

0 comments on commit c7e5fdf

Please sign in to comment.