Skip to content

Commit

Permalink
add support for torch.float16 and torch.bfloat16
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/FBGEMM#2992

# context
* We found the new operator `permute_multi_embedding` can't support `torch.float16` in an inference test
* added test to cover the dtype support
* before the operator change, we see the following error
```
Failures:

  1) torchrec.sparse.tests.test_jagged_tensor.TestKeyedTensorRegroupOp: test_multi_permute_dtype
    1) RuntimeError: expected scalar type Float but found Half
      File "torchrec/sparse/tests/test_jagged_tensor.py", line 2798, in test_multi_permute_dtype
        outputs = torch.ops.fbgemm.permute_multi_embedding(
      File "torch/_ops.py", line 1113, in __call__
        return self._op(*args, **(kwargs or {}))
```
* suspicion is that in the cpu operator, there are tensor data access with `data_ptr<float>` in the code, which limited the dtype could only be `float32`
```
          auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset;
          auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset;
```

# changes
* use `FBGEMM_DISPATCH_FLOATING_TYPES` to dispatch the dtype to template `scalar_t`.
* after the change the operator can support `float16`, `bfloat16`

WARNING: somehow this operator still can't support `int` types.

Reviewed By: sryap

Differential Revision: D57143637
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Aug 15, 2024
1 parent 5e30669 commit 11dcb3e
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2775,6 +2775,52 @@ def test_multi_permute_forward(self, device_str: str, batch_size: int) -> None:
for out, ref in zip(outputs, refs):
torch.testing.assert_close(out, ref)

@repeat_test(
device_str=["meta", "cpu", "cuda"],
dtype=[
# torch.int,
# torch.uint8,
# torch.int8,
# torch.int16,
# torch.float64,
torch.float,
torch.float32,
torch.float16,
torch.bfloat16,
],
)
def test_multi_permute_dtype(self, device_str: str, dtype: torch.dtype) -> None:
if device_str == "cuda" and not torch.cuda.is_available():
return
else:
device = torch.device(device_str)
batch_size = 4
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[3, 4], [5, 6, 7], [8]]
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
values = [
torch.randn(batch_size, sum(L), device=device, dtype=dtype) for L in lengths
]
permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments(
values[0], keys, lengths, groups
)
outputs = torch.ops.fbgemm.permute_multi_embedding(
values, permutes, in_shapes, out_shapes, out_lengths
)

if device_str == "meta":
for out, ref in zip(outputs, out_lengths):
self.assertEqual(out.shape, (batch_size, ref))
else:
refs = [[] for _ in groups]
for i in range(permutes.size(0)):
in_idx, out, in_start, _, length, _ = permutes[i].tolist()
refs[out].append(values[in_idx][:, in_start : (in_start + length)])
refs = [torch.cat(ref, dim=1) for ref in refs]
for out, ref in zip(outputs, refs):
torch.testing.assert_close(out, ref)
self.assertEqual(out.dtype, ref.dtype)

@repeat_test(
["cpu", 32, [[3, 4], [5, 6, 7], [8]]],
["cuda", 128, [[96, 256], [512, 128, 768], [1024]]],
Expand Down

0 comments on commit 11dcb3e

Please sign in to comment.