Code compile on H100
jainapurva committed Sep 12, 2024
commit ed4ce1c (parent 85d03de)
Showing 2 changed files with 6 additions and 2 deletions.
scripts/hf_eval.py (3 changes: 2 additions & 1 deletion)
@@ -66,6 +66,7 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, spars
         model = autoquant(model.to(device=device))
 
     if quantization != "autoquant" and compile:
+        model = model.to(device)
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
 
     if sparsity == "semi_sparse":
@@ -89,7 +90,7 @@ def all_linear(mod, name):
     with torch.no_grad():
         result = evaluate(
             HFLM(
-                pretrained=model.to(device),
+                pretrained=model,
                 tokenizer=tokenizer,
                 batch_size=batch_size,
                 max_length=max_length),
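
For context on the hf_eval.py change: the model is now moved to the target device before torch.compile is called, and the later HFLM(...) construction receives the already-placed model instead of issuing its own .to(device). Since torch.compile specializes the captured graph on tensor properties, including device, moving the model afterwards can force guard failures and recompilation, and can error under mode="max-autotune" with fullgraph=True; the commit title suggests this is what the H100 run hit. A minimal sketch of the pattern, with a hypothetical toy module (MyModel and the tensor shapes are illustrative, not from the repository):

import torch

# Hypothetical stand-in for the Hugging Face model loaded in hf_eval.py.
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)

device = "cuda" if torch.cuda.is_available() else "cpu"

model = MyModel()
model = model.to(device)  # place the model first, so compilation captures kernels for this device
model = torch.compile(model, mode="max-autotune", fullgraph=True)
out = model(torch.randn(2, 16, device=device))  # first call triggers compilation
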
torchao/kernel/intmm.py (5 changes: 4 additions & 1 deletion)
@@ -69,7 +69,10 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         input = (
             input.contiguous()
         )  # (it seems the transpose makes cublas check the above j constraint on i)
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        try:
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        except:
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
 else:
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
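
The intmm.py hunk wraps the fused integer matmul in a try/except so that a backend which rejects the int32-accumulating kernel (the commit title points at H100) falls back to a float32 matmul cast back to int32. A self-contained sketch of that fallback pattern, assuming the same private out_dtype higher-order op that the diff shows intmm.py already using; the function name is ours, and the sketch narrows the diff's bare except to except Exception:

import torch
from torch._higher_order_ops.out_dtype import out_dtype  # private PyTorch API, as used in torchao/kernel/intmm.py

def int_mm_with_fallback(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """int8 x int8 -> int32 matmul with a float32 fallback (hypothetical helper)."""
    try:
        # Fast path: aten.mm with int32 accumulation, fused into one op.
        return out_dtype(torch.ops.aten.mm.default, torch.int32, a, b)
    except Exception:
        # Fallback mirroring the commit: promote to float32, matmul, cast back.
        # float32 carries 24 bits of integer precision, so results stay exact
        # unless the reduction dimension is very large.
        return torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.int32)

a = torch.randint(-128, 128, (17, 64), dtype=torch.int8)
b = torch.randint(-128, 128, (64, 32), dtype=torch.int8)
c = int_mm_with_fallback(a, b)  # int32 tensor of shape (17, 32)
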
