W4A8 based on CUTLASS #880

Draft · alexsamardzic wants to merge 1 commit into main from w4a8-cutlass

Conversation

@alexsamardzic (Contributor) commented Sep 12, 2024



@facebook-github-bot added the CLA Signed label Sep 12, 2024
@alexsamardzic (Contributor, Author) commented Sep 12, 2024

The kernel implements W4A8 GEMM, with float16 scaling factors. Zero point support is to be added eventually; for now, several hacks (to be removed) are put in the code that force int8_dynamic_activation_int4_weight to do symmetric quantization for both activation and weight.
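For reference, here is a minimal sketch of what symmetric quantization means in this context - illustrative only, not the actual torchao implementation: the scale is derived from the per-row absolute maximum and the zero point is fixed at 0.

import torch

def symmetric_quantize(x: torch.Tensor, bits: int):
    # Symmetric quantization: zero point fixed at 0, per-row scale maps the
    # absolute maximum onto the positive end of the integer range.
    qmax = 2 ** (bits - 1) - 1                      # 127 for int8, 7 for int4
    scale = x.abs().amax(dim=-1, keepdim=True) / qmax
    q = torch.clamp(torch.round(x / scale), -qmax - 1, qmax).to(torch.int8)
    return q, scale.squeeze(-1)

# int8 activations and int4 weights, both with per-row scales and no groups
xq, xs = symmetric_quantize(torch.randn(16, 64), bits=8)
wq, ws = symmetric_quantize(torch.randn(32, 64), bits=4)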

There are several points to discuss:

CUTLASS would have to be made a dependency. IMO, the best approach to satisfy the dependency would be to install the nvidia-cutlass package; the only problem is that it doesn't always contain the latest changes in CUTLASS. An alternative would be to have the CUTLASS repo as a submodule of this repo, like in PyTorch.

The group quantization may be a problem. Let's say X is an input matrix of size MxK, with Xs the vector of input scales of size M, and W is a weight matrix of size NxK. If the group size parameter is equal to K, then the weight scales Ws will be a vector of size N, and an element of the output matrix Y of a linear operator would be calculated as follows (let's ignore bias for now, as it's not relevant):

$$y_{i,j}=\sum_{k}xs_{i}\cdot x_{i,k}\cdot w_{j,k}\cdot ws_{j}=xs_{i}\cdot ws_{j}\cdot \sum_{k}x_{i,k}\cdot w_{j,k}$$
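To illustrate, here is a quick numerical sanity check (just a sketch in plain PyTorch, not part of this PR) that the factored form matches dequantize-then-matmul: one integer GEMM, with the scales applied afterwards as an outer product.

import torch

M, N, K = 8, 16, 64
xq = torch.randint(-128, 128, (M, K))    # int8-range activations
wq = torch.randint(-8, 8, (N, K))        # int4-range weights
xs = torch.rand(M)                        # per-row activation scales
ws = torch.rand(N)                        # per-row weight scales (group size == K)

# Reference: dequantize first, then matmul in float.
y_ref = (xq.float() * xs[:, None]) @ (wq.float() * ws[:, None]).T

# Factored form: single integer GEMM, scales applied as an outer product.
y = (xq @ wq.T).float() * xs[:, None] * ws[None, :]

assert torch.allclose(y, y_ref, rtol=1e-3, atol=1e-2)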

The sum over k in the factored expression above can be efficiently calculated as a mixed integer data types GEMM on tensor cores, and the result can then be updated by multiplying the scale factors in. However, if the group size parameter is less than K, say 32 for example (32 < K, K % 32 == 0), then the weight scales will be a matrix of size Nx(K/32). In this case, an element of the output matrix Y of a linear operator would be calculated as follows:

$$y_{i,j}=\sum_{k}xs_{i}\cdot x_{i,k}\cdot w_{j,k}\cdot ws_{j,k/32}=xs_{i}\cdot \sum_{k}x_{i,k}\cdot w_{j,k}\cdot ws_{j,k/32}$$
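To make the difficulty concrete, here is an illustrative sketch (again not part of this PR) of what the group-quantized case amounts to: the reduction over K has to be split into K/32 partial integer GEMMs, each scaled by its own column of Ws before accumulation.

import torch

M, N, K, group = 8, 16, 128, 32
xq = torch.randint(-128, 128, (M, K))    # int8-range activations
wq = torch.randint(-8, 8, (N, K))        # int4-range weights
xs = torch.rand(M)                        # per-row activation scales
ws = torch.rand(N, K // group)            # per-group weight scales

# The per-group scales cannot be factored out of the reduction over K:
# each partial GEMM needs its own scaling before accumulation.
acc = torch.zeros(M, N)
for g in range(K // group):
    xg = xq[:, g * group:(g + 1) * group]
    wg = wq[:, g * group:(g + 1) * group]
    acc += (xg @ wg.T).float() * ws[:, g]    # scale differs per group and per output column
y = acc * xs[:, None]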

Now, the only approach possible in CUTLASS to do this calculation in integer mixed data types on tensor cores would be to split it into K/32 GEMMs, and try to run them at the same time as a so-called grouped GEMM. The code would be much more complicated, and the update with the scaling factors would still be different for each of these individual GEMMs, so I don't think this approach would be performant. So my question here is: does it make sense to create a quantization method different from int8_dynamic_activation_int4_weight that would match this kernel better, in particular one that would not use group quantization for the weight at all? (BTW, creating a new quantization method, or at least adding a variant of int8_dynamic_activation_int4_weight, is needed anyway, as this one does not pack two 4-bit weight values into a byte, which is required by CUTLASS for int8/int4 GEMM.)

Another related issue is zero point handling. Let's say Xz is a vector of size M of input zero point values, and Wz is a vector of size N of weight zero point values. Then the linear operator calculation, in PyTorch notation, would be as follows: Y=((X-Xz)*Xs)@((W-Wz)*Ws).T (again, let's ignore bias), which translates into the following calculation for an individual element of the output matrix Y:

$$ \begin{array}{lcl} y_{i,j} & = & \sum_{k}xs_{i}\cdot (x_{i,k}-xz_{i})\cdot (w_{j,k}-wz_{j})\cdot ws_{j} \\ & = & xs_{i}\cdot ws_{j}\cdot (\sum_{k}x_{i,k}\cdot w_{j,k}-wz_{j}\sum_{k}x_{i,k}-xz_{i}\sum_{k}w_{j,k}+K\cdot xz_{i}\cdot wz_{j}) \\ \end{array} $$

Only the first expression within the parentheses can be calculated on tensor cores as a mixed integer data types GEMM, while the sums in the next two expressions are best pre-calculated in the case of the weight values, or calculated on the fly during input quantization. So it seems to me these also call for a specialized type of quantization. (Note also that if group quantization is used, the complications mentioned above for Ws extend to Wz too.)
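Here is a quick numeric check of the expansion above - again just a plain PyTorch sketch, not part of this PR - showing that the row sums can be precomputed and combined with a single integer GEMM.

import torch

M, N, K = 8, 16, 64
xq = torch.randint(-128, 128, (M, K))    # int8-range activations
wq = torch.randint(-8, 8, (N, K))        # int4-range weights
xs, ws = torch.rand(M), torch.rand(N)    # scales
xz = torch.randint(-8, 8, (M,))          # activation zero points
wz = torch.randint(-2, 2, (N,))          # weight zero points

# Reference: dequantize first, then matmul in float.
y_ref = ((xq - xz[:, None]).float() * xs[:, None]) @ ((wq - wz[:, None]).float() * ws[:, None]).T

# Expanded form: single integer GEMM plus precomputed row sums.
x_rowsum = xq.sum(dim=1)                 # sum_k x_{i,k}
w_rowsum = wq.sum(dim=1)                 # sum_k w_{j,k}
acc = ((xq @ wq.T)
       - wz[None, :] * x_rowsum[:, None]
       - xz[:, None] * w_rowsum[None, :]
       + K * xz[:, None] * wz[None, :])
y = acc.float() * xs[:, None] * ws[None, :]

assert torch.allclose(y, y_ref, rtol=1e-3, atol=1e-2)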

All comments/suggestions welcome; in particular, I'm pretty much new to quantization specifics, so please let me know if I'm missing something obvious.

@msaroufim (Member) commented

I'm on PTO today and tomorrow so will review asap, apologies for the delay

@cpuhrsch (Contributor) commented

@alexsamardzic - Can we use the CUTLASS that ships with PyTorch? As in, should we change PyTorch to ship the headers used to build its CUTLASS kernels / does the PyTorch nightly already ship those? I see the test is using group size 128. I think it's ok if we don't necessarily support all group sizes or shapes right away.

We have some int4 support via the pattern matched in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/post_grad.py#L345-L403 which dispatches to https://github.com/pytorch/pytorch/blob/dab7d646d55a2b6696d51dee4816a6743ec1ae5a/torch/_inductor/kernel/unpack_mixed_mm.py#L76 - would an extension for int4x2 X int8 of this be interesting here?

@alexsamardzic (Contributor, Author) commented

I'm on PTO today and tomorrow so will review asap, apologies for the delay

Thanks Mark - it's really just a draft, so not yet ready for review, but it would be useful to discuss the points I mentioned in my comment above.

@alexsamardzic (Contributor, Author) commented

@alexsamardzic - Can we use the CUTLASS that ships with PyTorch? As in, should we change PyTorch to ship the headers used to build its CUTLASS kernels / does the PyTorch nightly already ship those?

This CUTLASS version is also lagging behind. My CUTLASS PR with mixed int4/int8 GEMM was merged after the latest (3.5.1) CUTLASS release; hopefully there will be a new release soon. But in any case, this is the kind of problem we'll have if we use more of CUTLASS from torchao - much of the time, the torchao build will have to be pointed to a bleeding-edge CUTLASS checkout.

I see the test is using group size 128. I think it's ok if we don't necessarily support all group sizes or shapes right away.

It uses group size 128 in order to force the weight scale to be a vector, and not a matrix. I tried to explain the issue in my comment above; if group quantization is obligatory here, then it's going to be rather complicated to make this work.

We have some int4 support via the pattern matched in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/post_grad.py#L345-L403 which dispatches to https://github.com/pytorch/pytorch/blob/dab7d646d55a2b6696d51dee4816a6743ec1ae5a/torch/_inductor/kernel/unpack_mixed_mm.py#L76 - would an extension for int4x2 X int8 of this be interesting here?

I'm just looking into the quantization code, to see whether it's possible to do it there - it's not hard to make this change, but CUTLASS in general doesn't support doing things before the GEMM (while fusing operations after the GEMM is calculated is reasonably well supported), so it would be best if the quantization code actually put the weight values in int4x2 format.
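For illustration, packing two signed int4 values into one byte might look roughly like the sketch below; which value goes into the low nibble is an assumption here and would have to match the layout the CUTLASS kernel expects.

import torch

def pack_int4x2(w: torch.Tensor) -> torch.Tensor:
    # w holds int4 values in [-8, 7], one per byte; pack adjacent pairs along
    # the last dimension into a single byte (low nibble first - an assumption).
    assert w.shape[-1] % 2 == 0
    w16 = w.to(torch.int16)
    lo = w16[..., 0::2] & 0xF
    hi = w16[..., 1::2] & 0xF
    return ((hi << 4) | lo).to(torch.uint8)

wq = torch.randint(-8, 8, (32, 64), dtype=torch.int8)
wq_packed = pack_int4x2(wq)              # shape (32, 32), two int4 values per byte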

@alexsamardzic (Contributor, Author) commented

Updated so that there is a new int8_dynamic_activation_int4_weight_cutlass quantization method available that, for now, quantizes both input and weight symmetrically, and doesn't use group quantization for the weight (so weight scales are always a vector). It should now be possible to try the kernel on arbitrary models, if they are quantized by the above quantization method.
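Applying the new method follows the usual quantize_ pattern - a minimal sketch, assuming a CUDA device and a float16 model, mirroring the micro-benchmark script further below.

import torch
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int4_weight_cutlass,
)

# A toy model; quantize_ replaces the weights of Linear modules in place.
model = torch.nn.Sequential(torch.nn.Linear(1024, 2048)).half().cuda()
quantize_(model, int8_dynamic_activation_int4_weight_cutlass())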

@@ -506,6 +508,41 @@ def int8_dynamic_activation_int4_weight(group_size=32, mapping_type=MappingType.
return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, mapping_type=mapping_type)


def apply_int8_dynamic_activation_int4_weight_quant_cutlass(weight):
A reviewer (Contributor) commented on this hunk:

can this be represented as a different Layout for int8 dynamic activation/int4 weight quantization? docs for Packing/Layout can be found in #391 "Layout and Packing" and simplified example in https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/my_dtype_tensor_subclass.py

@alexsamardzic (Contributor, Author) replied:

Thanks for the pointer! Yes, this will need refinement in this and several other places, as I learn to do things the "torchao way"; but my main goal initially is to connect the dots, so that some benchmarks can be run and we can verify that CUTLASS provides some value here.

@alexsamardzic alexsamardzic force-pushed the w4a8-cutlass branch 7 times, most recently from f6383ca to 02f8805 Compare September 17, 2024 08:26
@alexsamardzic (Contributor, Author) commented Sep 17, 2024

Made some minor updates, including adding support for bfloat16.

Micro-benchmarking script
import copy

import torch

from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    unwrap_tensor_subclass,
)
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int4_weight_cutlass,
)

# FIXME: change this!
_CUTLASS_DIR = ".../cutlass"


class ToyModel(torch.nn.Module):
    def __init__(self, nin, nout1, nout2):
        super().__init__()
        self.linear1 = torch.nn.Linear(nin, nout1)
        self.linear2 = torch.nn.Linear(nout1, nout2, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


methodq = int8_dynamic_activation_int4_weight_cutlass()
compile = False
dtype = torch.float16  # dtype = torch.bfloat16
device = "cuda"
bs, nin, nout1, nout2 = 256, 1024, 2048, 128

inputs = (torch.randn((1, bs, nin), dtype=dtype, device=device),)
model = ToyModel(nin, nout1, nout2).eval().to(dtype).to(device)
modelq = copy.deepcopy(model)

if compile:
    model = torch.compile(model, mode="max-autotune")

quantize_(modelq, methodq)
if not TORCH_VERSION_AT_LEAST_2_5:
    unwrap_tensor_subclass(modelq)

if compile:
    modelq = torch.compile(
        modelq,
        options={
            "max_autotune": True,
            "autotune_in_subproc": False,
            "max_autotune_gemm_backends": "Triton,CUTLASS",
            "cuda.cutlass_dir": _CUTLASS_DIR,
            "use_mixed_mm": True,
        },
    )


if __name__ == "__main__":
    from torchao.utils import benchmark_model

    nruns = 100
    torch._dynamo.reset()
    time = benchmark_model(model, nruns, inputs)
    timeq = benchmark_model(modelq, nruns, inputs)
    print(f"original model mean time  : {time:8.3f}")
    print(f"quantized model mean time : {timeq:8.3f}")
    print(f"speedup by quantization   : {time / timeq:8.3f}")

For the particular shapes given in the script above, on A100 the micro-benchmark shows around a 2x speedup over the case when float16 MM is used, and around a 1.8x speedup over the case when bfloat16 MM is used. (Note that this is for eager mode execution, as compilation to the corresponding CUTLASS kernel is not yet supported by PyTorch.)

Patch to run torchao/_models/llama/generate.py
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 5fb905d..e5b891b 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -206,6 +206,7 @@ def main(
             quantize_,
             int8_weight_only,
             int8_dynamic_activation_int8_weight,
+            int8_dynamic_activation_int4_weight_cutlass,
             int4_weight_only,
             fpx_weight_only,
             uintx_weight_only,
@@ -216,6 +217,8 @@ def main(
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
             quantize_(model, int8_dynamic_activation_int8_weight())
+        if "w4a8-cutlass" in quantization:
+            quantize_(model, int8_dynamic_activation_int4_weight_cutlass())
         if "int4wo" in quantization:
             if "hqq" in quantization:
                 use_hqq=True
@@ -414,7 +417,7 @@ if __name__ == '__main__':
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
     parser.add_argument('-q', '--quantization', type=str, 
         help=(
-            'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+            'Which quantization techniques to apply: int8dq, w4a8-cutlass, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
             +'autoquant-int4, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
         )
     )
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 1df3549..1252bb8 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1158,6 +1158,7 @@ implements = AffineQuantizedTensor.implements
 # so that these can be shared by F.linear, aten.mm, aten.addmm dispatches
 
 def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
+    return False
     return (
         isinstance(input_tensor, AffineQuantizedTensor) and
         _aqt_is_int8_reduced_range(input_tensor) and
diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py
index 3005cb1..451d0e6 100644
--- a/torchao/kernel/intmm.py
+++ b/torchao/kernel/intmm.py
@@ -54,6 +54,8 @@ if TORCH_VERSION_AT_LEAST_2_2:
             and k_is_nonzero_multiple_of_8
         )
 
+        bad_dimensions_for_cublas = False
+
         if device_cpu or bad_dimensions_for_cublas:
             # fallback path
             return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(

With the patch above, I was able to run the Llama generate.py script. The command to run is as follows:

python generate.py -q w4a8-cutlass

and the output is as follows (again, this is run on A100):

==========
Average tokens/sec: 10.21
Average Bandwidth: 33.78 GB/s
Peak Memory Usage: 14.22 GB
Model Size: 3.31 GB

while the reference output, for the case when no arguments are supplied to generate.py, is as follows:

==========
Average tokens/sec: 32.87
Average Bandwidth: 434.31 GB/s
Peak Memory Usage: 13.62 GB
Model Size: 13.21 GB

So tokens/sec is more than 3x slower, but this is not even that bad, considering that the batch size is 1 here, and that the CUTLASS code is hard-coded for a block of threads to handle an input tile of size 128 along that same dimension, so most of the work is wasted.

So there is room for improvement regarding the speed. The generated text is garbage, however. Even for the micro-benchmark above, output values visibly deviate from the values produced when native precision is used (but at least they resemble each other).

@alexsamardzic (Contributor, Author) commented Sep 18, 2024

Made an update - it turns out that CUTLASS itself needs a fix (posted below for now), and with it the generate.py script for the Llama model generates meaningful content.

CUTLASS fix
diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
index 1692cc30..5a1b164c 100644
--- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
+++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
@@ -263,6 +263,44 @@ struct DefaultIteratorsTensorOp<
   static int const kFragmentsPerIteration = 2;
 };
 
+/// Partial specialization for bfloat16 <= int32_t x 8 epilogues avoids shared memory bank conflicts.
+template <
+  typename ThreadblockShape,
+  typename WarpShape,
+  typename InstructionShape,
+  typename ThreadMap
+>
+struct DefaultIteratorsTensorOp<
+  bfloat16_t, 
+  int32_t, 
+  8, 
+  ThreadblockShape, 
+  WarpShape, 
+  InstructionShape, 
+  ThreadMap> {
+  
+  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
+    WarpShape,
+    InstructionShape,
+    int32_t,
+    32,
+    16,
+    8,
+    8
+  >;
+
+  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
+    ThreadMap,
+    int32_t,
+    32,
+    16,
+    8,
+    8
+  >;
+
+  static int const kFragmentsPerIteration = 2;
+};
+
 /// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts.
 /// Threadblock::kN = 256 still has bank conflicts.
 template <

On the other side, I tried adapting the tile sizes processed by a block/warp of threads of the corresponding CUTLASS kernel, in order to account for the fact that the batch size is 1 here. Here is an example of such a change:

+++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
@@ -418,8 +418,8 @@ s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale,
   using ElementA = int8_t;
   using ElementB = cutlass::int4b_t;
   using ElementAccumulator = int32_t;
-  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
+  using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
   using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
   AT_DISPATCH_SWITCH(
     input_scale.scalar_type(),

However, tokens/sec is not much improved this way. Thus, the performance of this kernel for the Llama model will require more work.

Edit: CUTLASS fix posted upstream here.
