W4A8 based on CUTLASS #880
base: main
Conversation
The kernel implements W4A8 GEMM, with float16 scaling factors. Zero point support is to be eventually added; for now several hacks (to be removed) are put in the code that force the zero points to 0.

There are several points to discuss:

CUTLASS would have to be made a dependency. IMO, the best approach to satisfy the dependency would be to install [...].

The group quantization may be a problem. Let's say [...]. The sum in the last expression could be efficiently calculated as a mixed integer data types GEMM on tensor cores, and the result could then be updated by multiplying the scale factors in. However, if the group size parameter is less than the length of the reduction dimension, the scale factors depend on the reduction index, and no longer factor out of the sum. Now, the only approach possible in CUTLASS to do this calculation in mixed integer data types on tensor cores would be to split it into one GEMM per group, and combine the scaled partial results afterwards.

Another related issue is zero point handling. Let's say [...]. Only the first expression within the parentheses could be calculated on tensor cores as a mixed integer data types GEMM, while the sums in the next two expressions are best pre-calculated in the case of weight values, or calculated on the fly during input quantization. So it seems to me these are also calling for a specialized type of quantization. (Note also that if group quantization is used, the above mentioned complications apply here as well.)

All comments/suggestions are welcome; in particular, I'm pretty much new to quantization specifics, so please let me know if I'm missing something obvious.
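The expressions elided from the comment above describe the standard scale/zero-point decomposition of a quantized matmul. The following PyTorch sketch spells it out; the variable names and shapes are my own, not from the PR, and integer values are stored in int32 so the integer matmuls run on CPU:

import torch

# q_x: (m, k) int8-range activations; q_w: (k, n) int4-range weights.
m, k, n = 4, 64, 8
q_x = torch.randint(-128, 128, (m, k), dtype=torch.int32)
q_w = torch.randint(-8, 8, (k, n), dtype=torch.int32)
s_x = torch.rand(m, 1)  # per-row activation scales
s_w = torch.rand(1, n)  # per-column weight scales

# With per-row/per-column scales, the scales factor out of the inner sum:
#   y[i, j] = s_x[i] * s_w[j] * sum_k q_x[i, k] * q_w[k, j]
# so the whole reduction is a single mixed integer GEMM, with the scale
# factors multiplied in afterwards.
y = s_x * (q_x @ q_w).to(torch.float32) * s_w

# With group quantization (group size g < k), s_w also depends on the group
# index along k, so it cannot be pulled out of a single integer GEMM; the
# reduction has to be split into one GEMM per group and then recombined.

# Zero points: the product of shifted operands expands as
#   sum_k (q_x[i,k] - z_x[i]) * (q_w[k,j] - z_w[j])
#     =   sum_k q_x[i,k] * q_w[k,j]          # the only true GEMM term
#       - z_w[j] * sum_k q_x[i,k]            # computable during input quant
#       - z_x[i] * sum_k q_w[k,j]            # pre-computable from weights
#       + k * z_x[i] * z_w[j]                # constant per output element
z_x = torch.randint(-8, 8, (m, 1), dtype=torch.int32)
z_w = torch.randint(-2, 2, (1, n), dtype=torch.int32)
lhs = (q_x - z_x) @ (q_w - z_w)
rhs = (q_x @ q_w
       - z_w * q_x.sum(1, keepdim=True)
       - z_x * q_w.sum(0, keepdim=True)
       + k * z_x * z_w)
torch.testing.assert_close(lhs, rhs)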
I'm on PTO today and tomorrow so will review asap, apologies for the delay
@alexsamardzic - Can we use the CUTLASS that ships with PyTorch? As in, should we change PyTorch to ship the headers used to build its CUTLASS kernels / does the PyTorch nightly already ship those?

I see the test is using group size 128. I think it's ok if we don't necessarily support all group sizes or shapes right away.

We have some int4 support via the pattern matched in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/post_grad.py#L345-L403 which dispatches to https://github.com/pytorch/pytorch/blob/dab7d646d55a2b6696d51dee4816a6743ec1ae5a/torch/_inductor/kernel/unpack_mixed_mm.py#L76 - would an extension for int4x2 X int8 of this be interesting here?
Thanks Mark - it's really just a draft, so not yet ready for review, but it would be useful to discuss the points that I mentioned in my comment above.
This CUTLASS version is also lagging behind. My CUTLASS PR with mixed int4/int8 GEMM was merged after the latest (3.5.1) CUTLASS release; hopefully there will be a new release soon. But in any case, this is the kind of problem we'll have if we use more of CUTLASS from torchao - for a long time, the torchao build will have to be pointed to a bleeding edge CUTLASS checkout.
It uses group size 128 in order to force the weight scale to be a vector, and not a matrix. I tried to explain the issue in my comment above: if group quantization is obligatory here, then it's going to be rather complicated to make this work.
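To illustrate the "vector vs. matrix" point, a minimal sketch of the scale shapes involved; the function name and the symmetric-int4 scale formula are my own assumptions, not the PR's code. For a weight of shape (n, k), group quantization yields scales of shape (n, k // group_size), which collapses to a per-row vector only when group_size equals k:

import torch

def group_scales(weight, group_size):
    # Hypothetical symmetric int4 scales: one scale per group of
    # `group_size` consecutive elements along the reduction dimension.
    n, k = weight.shape
    assert k % group_size == 0
    groups = weight.reshape(n, k // group_size, group_size)
    return groups.abs().amax(dim=-1) / 7.0  # int4 positive range is [0, 7]

w = torch.randn(256, 128)
print(group_scales(w, 32).shape)   # torch.Size([256, 4]) - a matrix of scales
print(group_scales(w, 128).shape)  # torch.Size([256, 1]) - a per-row vector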
I'm just looking into the quantization code, to see if it is possible to do it there - it's not hard to make this change, but CUTLASS in general doesn't support doing things before the GEMM (while fusing operations after the GEMM is calculated is reasonably well supported), so it would be best if the quantization code actually put the weight values in int4x2 format.
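For reference, a rough sketch of int4x2 packing under one possible convention (two signed int4 values per byte, low nibble first); the layout the CUTLASS kernel actually expects may differ:

import torch

def pack_int4x2(w):
    # Pack pairs of int4 values (stored as int8 in [-8, 7]) into single
    # bytes: element 2*i goes to the low nibble, element 2*i+1 to the high.
    assert w.dtype == torch.int8 and w.shape[-1] % 2 == 0
    lo = w[..., 0::2].to(torch.int32) & 0xF
    hi = w[..., 1::2].to(torch.int32) & 0xF
    return (lo | (hi << 4)).to(torch.uint8)

def unpack_int4x2(packed):
    # Inverse of pack_int4x2, sign-extending each nibble back to int8.
    p = packed.to(torch.int32)
    lo, hi = p & 0xF, (p >> 4) & 0xF
    lo = torch.where(lo >= 8, lo - 16, lo)
    hi = torch.where(hi >= 8, hi - 16, hi)
    return torch.stack((lo, hi), dim=-1).flatten(-2).to(torch.int8)

w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)
assert torch.equal(unpack_int4x2(pack_int4x2(w)), w)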
Force-pushed from 1bacd02 to e1a1ff1.
Updated so that there is a new [...].
@@ -506,6 +508,41 @@ def int8_dynamic_activation_int4_weight(group_size=32, mapping_type=MappingType.
    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, mapping_type=mapping_type)

def apply_int8_dynamic_activation_int4_weight_quant_cutlass(weight):
Can this be represented as a different Layout for int8 dynamic activation/int4 weight quantization? Docs for Packing/Layout can be found in #391 "Layout and Packing", and a simplified example in https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/my_dtype_tensor_subclass.py
Thanks for the pointer! Yes, this will need refinement in this and several other places, as I learn about doing things the "torchao way"; but my main goal initially is to connect the dots, so that some benchmarks can be run, and we can verify that CUTLASS provides some value here.
Force-pushed from f6383ca to 02f8805.
Made some minor updates, including adding support for bfloat16. Micro-benchmarking script: [...]
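The script itself was elided from this page; as a stand-in, here is a rough sketch of the kind of comparison described, using torch.utils.benchmark to time the float16 baseline. The shapes are placeholders, and the quantized entry point is left abstract since its name was elided above:

import torch
from torch.utils import benchmark

# Placeholder shapes; the PR's script used its own particular shapes.
m, k, n = 1, 4096, 4096
x = torch.randn(m, k, dtype=torch.float16, device="cuda")
w = torch.randn(n, k, dtype=torch.float16, device="cuda")

fp16 = benchmark.Timer(stmt="x @ w.t()", globals={"x": x, "w": w}).blocked_autorange()
print("float16 MM:", fp16.median)

# The W4A8 path would be timed the same way, e.g. (hypothetical op name):
#   benchmark.Timer(stmt="w4a8_mm(q_x, s_x, q_w, s_w)",
#                   globals={...}).blocked_autorange()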
For the particular shapes given in the script above, on A100 the micro-benchmark shows around 2x speedup over the case when float16 MM is used, and around 1.8x speedup over the case when bfloat16 MM is used. (Note that this is for eager mode execution, as compilation to the corresponding CUTLASS kernel is not yet supported by PyTorch.) Patch to run torchao/_models/llama/generate.py: [...]
With the patch above, I was able to run Llama as follows: [...]

and the output is as follows (again, this is run on A100): [...]

while the reference output, for the case when no arguments are supplied to [...], is: [...]
So the tokens/sec is more than 3x slower, but this is not even that bad, considering that the batch size is 1 here, and that the CUTLASS code is hard-coded for a block of threads to handle an input tile size of 128 in the same dimension, so most of the work is wasted. So there is room for improvement regarding the speed. The text generated is garbage, however. Even for the micro-benchmark above, output values visibly deviate from the values produced when native precision is used (but at least they resemble each other).
Force-pushed from 02f8805 to 575e074.
Force-pushed from 575e074 to 956fc80.
Made an update - it turns out that CUTLASS actually needs a fix (posted below for now), and then [...]. CUTLASS fix: [...]
On the other side, I tried adapting the tile sizes processed by a block/warp of threads of the corresponding CUTLASS kernel, in order to adapt to the fact that the batch size is 1 here. Here is an example of such a change: [...]
However, tokens/sec is not much improved this way. Thus, the performance of this kernel for the Llama model will require more work. Edit: CUTLASS fix posted upstream here.
@msaroufim @cpuhrsch