Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add support for torch.float16 and torch.bfloat16
Summary: X-link: pytorch/FBGEMM#2992 # context * We found the new operator `permute_multi_embedding` can't support `torch.float16` in an inference test * added test to cover the dtype support * before the operator change, we see the following error ``` Failures: 1) torchrec.sparse.tests.test_jagged_tensor.TestKeyedTensorRegroupOp: test_multi_permute_dtype 1) RuntimeError: expected scalar type Float but found Half File "torchrec/sparse/tests/test_jagged_tensor.py", line 2798, in test_multi_permute_dtype outputs = torch.ops.fbgemm.permute_multi_embedding( File "torch/_ops.py", line 1113, in __call__ return self._op(*args, **(kwargs or {})) ``` * suspicion is that in the cpu operator, there are tensor data access with `data_ptr<float>` in the code, which limited the dtype could only be `float32` ``` auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset; auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset; ``` # changes * use `FBGEMM_DISPATCH_FLOATING_TYPES` to dispatch the dtype to template `scalar_t`. * after the change the operator can support `float16`, `bfloat16` WARNING: somehow this operator still can't support `int` types. Reviewed By: sryap Differential Revision: D57143637
- Loading branch information