W4A8 based on CUTLASS
alexsamardzic committed Sep 17, 2024
1 parent bd264f9 commit 02f8805
Showing 10 changed files with 700 additions and 0 deletions.
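In short, this commit adds a new quantization option, int8_dynamic_activation_int4_weight_cutlass() (W4A8: int4 weights, dynamically quantized int8 activations, with the matmul dispatched to a CUTLASS kernel). A minimal usage sketch, distilled from the new test added below; the API name comes from this commit, the rest is ordinary PyTorch:

    import torch
    from torchao.quantization import quantize_
    from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight_cutlass

    # Quantize an fp16 CUDA model in place, then run it as usual.
    model = torch.nn.Linear(128, 256).half().cuda()
    quantize_(model, int8_dynamic_activation_int4_weight_cutlass())
    out = model(torch.rand((64, 128)).half().cuda())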
1 change: 1 addition & 0 deletions docs/source/api_ref_quantization.rst
@@ -18,6 +18,7 @@ torchao.quantization
     Int4WeightOnlyQuantizer
     quantize_
     int8_dynamic_activation_int4_weight
+    int8_dynamic_activation_int4_weight_cutlass
     int8_dynamic_activation_int8_weight
     int4_weight_only
     int8_weight_only
7 changes: 7 additions & 0 deletions setup.py
@@ -65,6 +65,12 @@ def get_extensions():
     extension = CUDAExtension if use_cuda else CppExtension
 
     if not IS_WINDOWS:
+        import cutlass_library
+        cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
+        cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include")
+        # FIXME: remove this once CUTLASS package updated to include int4/int8 MM
+        cutlass_include_dir = "/data/quansight/scratch/cutlass/include"
+
         extra_link_args = []
         extra_compile_args = {
             "cxx": [
@@ -74,6 +80,7 @@ def get_extensions():
         "nvcc": [
             "-O3" if not debug_mode else "-O0",
             "-t=0",
+            "-I" + cutlass_include_dir,
         ]
     }
 
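The two hunks above resolve the CUTLASS headers from the cutlass_library Python package, temporarily override that path with a local checkout (see the FIXME), and pass the result to nvcc via -I. A quick sanity check that a resolved directory actually contains the CUTLASS headers (a hypothetical snippet, not part of the commit):

    import os
    import cutlass_library

    cutlass_include_dir = os.path.join(
        os.path.dirname(cutlass_library.__file__), "source", "include"
    )
    # The umbrella header sits at cutlass/cutlass.h under the include root.
    print(os.path.isfile(os.path.join(cutlass_include_dir, "cutlass", "cutlass.h")))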
2 changes: 2 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -6,6 +6,7 @@
     int4_weight_only,
     int8_weight_only,
     int8_dynamic_activation_int4_weight,
+    int8_dynamic_activation_int4_weight_cutlass,
     int8_dynamic_activation_int8_weight,
     int8_dynamic_activation_int8_semi_sparse_weight,
     float8_weight_only,
@@ -25,6 +26,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
     base_functions = [
         int8_weight_only(),
         int8_dynamic_activation_int4_weight(),
+        int8_dynamic_activation_int4_weight_cutlass(),
         int8_dynamic_activation_int8_weight(),
     ]
     if do_int4:
1 change: 1 addition & 0 deletions test/quantization/test_quant_api.py
@@ -39,6 +39,7 @@
     Quantizer,
     TwoStepQuantizer,
     int8_dynamic_activation_int4_weight,
+    int8_dynamic_activation_int4_weight_cutlass,
     int4_weight_only,
     int8_weight_only,
     int8_dynamic_activation_int8_weight,
49 changes: 49 additions & 0 deletions test/test_s8s4_linear_cutlass.py
@@ -0,0 +1,49 @@
+# FIXME: move this test to the appropriate test file!!!
+
+import copy
+
+from torchao.quantization import quantize_
+from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight_cutlass
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+
+import pytest
+
+
+class ToyModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(128, 256)
+        self.linear2 = torch.nn.Linear(256, 128, bias=False)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = torch.nn.functional.relu(x)
+        x = self.linear2(x)
+        return x
+
+
+class TestS8S4LinearCUTLASS(TestCase):
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_s8s4_linear_cutlass_(self):
+        # FIXME: remove this!
+        torch.manual_seed(0)
+
+        input = torch.rand((64, 128)).half().cuda()
+        model = ToyModel().half().cuda()
+
+        output_ref = model(input)
+
+        modelq = copy.deepcopy(model)
+        quantize_(modelq, int8_dynamic_activation_int4_weight_cutlass())
+        output = modelq(input)
+
+        assert torch.allclose(output, output_ref, rtol=1e-1, atol=0)
+
+
+if __name__ == "__main__":
+    run_tests()
[Diffs for the remaining 5 changed files not shown.]