
add test for supporting torch.float16 and torch.bfloat16 #2300

Open

wants to merge 1 commit into base: main

Commits on Aug 15, 2024

  1. add support for torch.float16 and torch.bfloat16

    Summary:
    X-link: pytorch/FBGEMM#2992
    
    # context
    * We found that the new operator `permute_multi_embedding` doesn't support `torch.float16` in an inference test
    * added a test to cover dtype support
    * before the operator change, the test fails with the following error
    ```
    Failures:
    
      1) torchrec.sparse.tests.test_jagged_tensor.TestKeyedTensorRegroupOp: test_multi_permute_dtype
        1) RuntimeError: expected scalar type Float but found Half
          File "torchrec/sparse/tests/test_jagged_tensor.py", line 2798, in test_multi_permute_dtype
            outputs = torch.ops.fbgemm.permute_multi_embedding(
          File "torch/_ops.py", line 1113, in __call__
            return self._op(*args, **(kwargs or {}))
    ```
    * the suspected cause is that the CPU operator accesses tensor data via `data_ptr<float>`, which restricts the supported dtype to `float32`
    ```
              auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset;
              auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset;
    ```
    
    # changes
    * use `FBGEMM_DISPATCH_FLOATING_TYPES` to dispatch the dtype to the template parameter `scalar_t` (see the sketch after this section)
    * after the change, the operator supports `float16` and `bfloat16`
    
    WARNING: somehow this operator still can't support `int` dtypes (likely because `FBGEMM_DISPATCH_FLOATING_TYPES` dispatches only floating-point types).
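
    For illustration, here is a minimal sketch of the dispatch pattern, assuming a hypothetical helper `copy_slice` and the header path `fbgemm_gpu/dispatch_macros.h`; this is not the actual operator code, only the shape of the fix:

    ```
    #include <cstring>

    #include <ATen/ATen.h>
    // Assumed header location; the macro behaves like ATen's
    // AT_DISPATCH_FLOATING_TYPES_AND2(Half, BFloat16, ...).
    #include <fbgemm_gpu/dispatch_macros.h>

    // Hypothetical helper illustrating the pattern, not the real kernel.
    void copy_slice(
        const at::Tensor& input,
        at::Tensor& output,
        int64_t in_offset,
        int64_t out_offset,
        int64_t length) {
      FBGEMM_DISPATCH_FLOATING_TYPES(
          input.scalar_type(), "permute_multi_embedding_cpu", [&] {
            // Inside the lambda, `scalar_t` is the concrete element type
            // (float, at::Half, or at::BFloat16), so data_ptr<scalar_t>()
            // no longer pins the kernel to float32.
            const scalar_t* inp = input.data_ptr<scalar_t>() + in_offset;
            scalar_t* outp = output.data_ptr<scalar_t>() + out_offset;
            std::memcpy(outp, inp, length * sizeof(scalar_t));
          });
    }
    ```

    The macro expands to a switch over the tensor's scalar type and instantiates the lambda once per supported floating dtype, which is also why integer dtypes remain outside its coverage.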
    
    Reviewed By: sryap
    
    Differential Revision: D57143637
    TroyGarden authored and facebook-github-bot committed Aug 15, 2024
    Commit SHA: 11dcb3e