From 34e32c20cbb3ac0c95993f1d83c287cd1335cce5 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Tue, 23 Jul 2024 15:13:03 -0700 Subject: [PATCH] Refactor test_jagged_tensor (#2241) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2241 # context * refactor test_jagged_tensor file for better structure of the operator tests * update the forward, backward, and device tests. * **consolidate** _cpu/_gpu/_meta tests * use a `repeat_test` decorator to run a test once for each set of arguments # usages * list usage (each positional list is one set of positional arguments) ``` @repeat_test( ["cpu", 32, [[3, 4], [5, 6, 7], [8]]], ["cuda", 128, [[96, 256], [512, 128, 768], [1024]]], ) def test_multi_permute_backward( self, device_str: str, batch_size: int, lengths: List[List[int]] ) -> None: if device_str == "cuda" and not torch.cuda.is_available(): return else: device = torch.device(device_str) ``` * dict usage (the test runs for every combination of the keyword values) ``` @repeat_test(device_str=["cpu", "cuda"], batch_size=[16, 1024]) def test_multi_permute_noncontiguous( self, device_str: str, batch_size: int ) -> None: ``` Reviewed By: ge0405 Differential Revision: D43653576 --- torchrec/sparse/tests/test_jagged_tensor.py | 863 +++++++++----------- torchrec/sparse/tests/utils.py | 27 +- 2 files changed, 391 insertions(+), 499 deletions(-) diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index dfa418064..ff1e54d82 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -9,7 +9,7 @@ import unittest -from typing import Dict, List, Tuple +from typing import Callable, Dict, List, Tuple import torch import torch.utils._pytree as pytree @@ -25,8 +25,10 @@ KeyedJaggedTensor, KeyedTensor, kjt_is_equal, + permute_multi_embedding, + regroup_kts, ) -from torchrec.sparse.tests.utils import build_groups, build_kts +from torchrec.sparse.tests.utils import build_groups, build_kts, repeat_test torch.fx.wrap("len") @@ -1398,502 +1400,6 @@ def test_permute_vb(self) -> None: ) 
self.assertEqual(permuted_jag_tensor.weights_or_none(), None) - def test_kt_regroup_arguments(self) -> None: - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[3, 4], [5, 6, 7], [8]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - for device in ["cpu", "meta", "cuda"]: - if device == "cuda" and not torch.cuda.is_available(): - continue # skip meta tests if no cuda is available - device = torch.device(device) - permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( - torch.empty(0, device=device), keys, lengths, groups - ) - ref_permutes = [ - [0, 0, 0, 0, 3, 4], # f1, jump to 4, as a start - [1, 0, 0, 3, 5, 0], # f3 - [0, 1, 3, 0, 4, 0], # f2 - [1, 2, 5, 0, 6, 0], # f4 - [0, 2, 0, 6, 3, -6], # f1 jump to 6, as in a jump sequence - [2, 2, 0, 9, 8, 0], # f6 - [0, 3, 0, 0, 3, -8], # f1 jump stop, as out of boundary - [1, 3, 11, 3, 7, 0], # f5 - ] - if device.type == "meta": - self.assertEqual( - permutes.shape, (len(ref_permutes), len(ref_permutes[0])) - ) - self.assertEqual(in_shapes.shape, (3,)) - self.assertEqual(out_shapes.shape, (4,)) - else: - self.assertTrue( - torch.equal( - permutes, - torch.tensor(ref_permutes, dtype=torch.int32, device=device), - ) - ) - self.assertEqual(in_shapes.tolist(), [7, 18, 8]) - self.assertEqual(out_shapes.tolist(), [8, 4, 17, 10]) - self.assertEqual(out_lengths, [8, 4, 17, 10]) - - def test_multi_permute_forward_cpu(self) -> None: - batch_size = 32 - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[3, 4], [5, 6, 7], [8]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - values = [torch.randn(batch_size, sum(lens), device="cpu") for lens in lengths] - permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( - values[0], keys, lengths, groups - ) - refs = [[] for _ in groups] - for i in range(permutes.size(0)): - in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() - refs[out_idx].append(values[in_idx][:, in_start : 
(in_start + length)]) - refs = [torch.cat(ref, dim=1) for ref in refs] - outputs = torch.ops.fbgemm.permute_multi_embedding( - values, permutes, in_shapes, out_shapes, out_lengths - ) - for out, ref in zip(outputs, refs): - self.assertTrue(torch.allclose(out, ref)) - - def test_multi_permute_forward_meta(self) -> None: - batch_size = 32 - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[3, 4], [5, 6, 7], [8]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - values = [torch.randn(batch_size, sum(lens), device="meta") for lens in lengths] - permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( - values[0], keys, lengths, groups - ) - outputs = torch.ops.fbgemm.permute_multi_embedding( - values, permutes, in_shapes, out_shapes, out_lengths - ) - for out, ref in zip(outputs, out_lengths): - self.assertEqual(out.shape, (batch_size, ref)) - - # pyre-ignore[56] - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "CUDA is not available", - ) - def test_multi_permute_forward_gpu(self) -> None: - batch_size = 1024 - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[96, 256], [512, 128, 768], [1024]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - values = [torch.randn(batch_size, sum(lens), device="cuda") for lens in lengths] - permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( - values[0], keys, lengths, groups - ) - refs = [[] for _ in groups] - for i in range(permutes.size(0)): - in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() - refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) - refs = [torch.cat(ref, dim=1) for ref in refs] - outputs = torch.ops.fbgemm.permute_multi_embedding( - values, permutes, in_shapes, out_shapes, out_lengths - ) - for out, ref in zip(outputs, refs): - self.assertTrue(torch.allclose(out, ref)) - - def test_multi_permute_backward_cpu(self) -> None: - batch_size = 32 - keys = [["f1", "f2"], ["f3", "f4", 
"f5"], ["f6"]] - lengths = [[3, 4], [5, 6, 7], [8]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - values = [ - torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True) - for lens in lengths - ] - ref_values = [v.detach() for v in values] - for v in ref_values: - v.requires_grad = True - permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( - values[0], keys, lengths, groups - ) - refs = [[] for _ in groups] - for i in range(permutes.size(0)): - in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() - refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) - refs = [torch.cat(ref, dim=1) for ref in refs] - outputs = torch.ops.fbgemm.permute_multi_embedding( - values, permutes, in_shapes, out_shapes, out_lengths - ) - for out, ref in zip(outputs, refs): - self.assertTrue(torch.allclose(out, ref)) - - ref_loss, loss = refs[0].sum(), outputs[0].sum() - for i in range(1, len(refs)): - ref_loss += (i + 1.1) * refs[i].sum() - loss += (i + 1.1) * outputs[i].sum() - ref_loss.backward() - loss.backward() - for val, ref in zip(values, ref_values): - val_grad, ref_grad = val.grad, ref.grad - assert isinstance(val_grad, torch.Tensor) - self.assertTrue(torch.allclose(val_grad, ref_grad)) - - # pyre-ignore[56] - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "CUDA is not available", - ) - def test_multi_permute_backward_gpu(self) -> None: - batch_size = 2048 - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[96, 256], [512, 128, 768], [1024]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - values = [ - torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True) - for lens in lengths - ] - ref_values = [v.detach() for v in values] - for v in ref_values: - v.requires_grad = True - permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( - values[0], keys, lengths, groups - ) - refs = [[] for _ in groups] - for i in 
range(permutes.size(0)): - in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() - refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) - refs = [torch.cat(ref, dim=1) for ref in refs] - outputs = torch.ops.fbgemm.permute_multi_embedding( - values, permutes, in_shapes, out_shapes, out_lengths - ) - for out, ref in zip(outputs, refs): - self.assertTrue(torch.allclose(out, ref)) - - ref_loss, loss = refs[0].sum(), outputs[0].sum() - for i in range(1, len(refs)): - ref_loss += (i + 1.1) * refs[i].sum() - loss += (i + 1.1) * outputs[i].sum() - ref_loss.backward() - loss.backward() - for val, ref in zip(values, ref_values): - val_grad, ref_grad = val.grad, ref.grad - assert isinstance(val_grad, torch.Tensor) - self.assertTrue(torch.allclose(val_grad, ref_grad)) - - def test_multi_permute_noncontiguous_cpu(self) -> None: - batch_size = 32 - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[3, 4], [5, 6, 7], [8]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - values = [ - torch.randn(sum(lens), batch_size, device="cpu", requires_grad=True) - for lens in lengths - ] - non_contiguous = [v.t() for v in values] - for value in non_contiguous: - self.assertFalse(value.is_contiguous()) - ref_values = [v.detach() for v in values] - for v in ref_values: - v.requires_grad = True - permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( - non_contiguous[0], keys, lengths, groups - ) - refs = [[] for _ in groups] - for i in range(permutes.size(0)): - in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() - refs[out_idx].append(ref_values[in_idx][in_start : (in_start + length), :]) - refs = [torch.cat(ref).t() for ref in refs] - outputs = torch.ops.fbgemm.permute_multi_embedding( - non_contiguous, permutes, in_shapes, out_shapes, out_lengths - ) - for out, ref in zip(outputs, refs): - self.assertTrue(torch.allclose(out, ref)) - - ref_loss, loss = refs[0].sum(), outputs[0].sum() - for i in 
range(1, len(refs)): - ref_loss += (i + 1.1) * refs[i].sum() - loss += (i + 1.1) * outputs[i].sum() - ref_loss.backward() - loss.backward() - for val, ref in zip(values, ref_values): - val_grad, ref_grad = val.grad, ref.grad - assert isinstance(val_grad, torch.Tensor) - self.assertTrue(torch.allclose(val_grad, ref_grad)) - - # pyre-ignore[56] - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "CUDA is not available", - ) - def test_multi_permute_noncontiguous_gpu(self) -> None: - batch_size = 1024 - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[96, 256], [512, 128, 768], [1024]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - values = [ - torch.randn(sum(lens), batch_size, device="cuda", requires_grad=True) - for lens in lengths - ] - non_contiguous = [v.t() for v in values] - for value in non_contiguous: - self.assertFalse(value.is_contiguous()) - ref_values = [v.detach() for v in values] - for v in ref_values: - v.requires_grad = True - permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( - non_contiguous[0], keys, lengths, groups - ) - refs = [[] for _ in groups] - for i in range(permutes.size(0)): - in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() - refs[out_idx].append(ref_values[in_idx][in_start : (in_start + length), :]) - refs = [torch.cat(ref).t() for ref in refs] - outputs = torch.ops.fbgemm.permute_multi_embedding( - non_contiguous, permutes, in_shapes, out_shapes, out_lengths - ) - for out, ref in zip(outputs, refs): - self.assertTrue(torch.allclose(out, ref)) - - ref_loss, loss = refs[0].sum(), outputs[0].sum() - for i in range(1, len(refs)): - ref_loss += (i + 1.1) * refs[i].sum() - loss += (i + 1.1) * outputs[i].sum() - ref_loss.backward() - loss.backward() - for val, ref in zip(values, ref_values): - val_grad, ref_grad = val.grad, ref.grad - assert isinstance(val_grad, torch.Tensor) - self.assertTrue(torch.allclose(val_grad, ref_grad)) - - def 
test_kt_regroup_arguments_op(self) -> None: - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[3, 4], [5, 6, 7], [8]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - batch_size = 5 - for device in ["cpu", "meta", "cuda"]: - if device == "cuda" and not torch.cuda.is_available(): - continue # skip meta tests if no cuda is available - device = torch.device(device) - embs = [torch.randn(batch_size, sum(l), device=device) for l in lengths] - permutes, in_shapes, out_shapes, out_lengths = ( - torch.ops.fbgemm.kt_regroup_arguments( - embs[0], - keys, - lengths, - groups, - ) - ) - ref_permutes = [ - [0, 0, 0, 0, 3, 4], # f1, jump to 4, as a start - [1, 0, 0, 3, 5, 0], # f3 - [0, 1, 3, 0, 4, 0], # f2 - [1, 2, 5, 0, 6, 0], # f4 - [0, 2, 0, 6, 3, -6], # f1 jump to 6, as in a jump sequence - [2, 2, 0, 9, 8, 0], # f6 - [0, 3, 0, 0, 3, -8], # f1 jump stop, as out of boundary - [1, 3, 11, 3, 7, 0], # f5 - ] - if device.type == "meta": - self.assertEqual( - permutes.shape, (len(ref_permutes), len(ref_permutes[0])) - ) - self.assertEqual(in_shapes.shape, (3,)) - self.assertEqual(out_shapes.shape, (4,)) - else: - self.assertTrue( - torch.equal( - permutes, - torch.tensor(ref_permutes, dtype=torch.int32, device=device), - ) - ) - self.assertEqual(in_shapes.tolist(), [7, 18, 8]) - self.assertEqual(out_shapes.tolist(), [8, 4, 17, 10]) - self.assertEqual(out_lengths, [8, 4, 17, 10]) - - def test_keyed_tensor_regroup_cpu_forward(self) -> None: - batch_size = 5 - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[3, 4], [5, 6, 7], [8]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - values = [ - torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True) - for lens in lengths - ] - permutes = [ - [0, 0, 0, 0, 3, 4], # f1, jump to 4, as a start - [1, 0, 0, 3, 5, 0], # f3 - [0, 1, 3, 0, 4, 0], # f2 - [1, 2, 5, 0, 6, 0], # f4 - [0, 2, 0, 6, 3, -6], # f1 jump to 6, as in a jump sequence - [2, 2, 0, 9, 8, 0], # 
f6 - [0, 3, 0, 0, 3, -8], # f1 jump stop, as out of boundary - [1, 3, 11, 3, 7, 0], # f5 - ] - refs = [[] for _ in groups] - for p in permutes: - in_idx, out_idx, in_start, _, length, _ = p - refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) - refs = [torch.cat(ref, dim=1) for ref in refs] - outputs = torch.ops.fbgemm.regroup_keyed_tensor( - values, - keys, - lengths, - groups, - ) - for out, ref in zip(outputs, refs): - self.assertTrue(torch.allclose(out, ref)) - - def test_keyed_tensor_regroup_meta_forward(self) -> None: - batch_size = 5 - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[3, 4], [5, 6, 7], [8]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - values = [ - torch.randn(batch_size, sum(lens), device="meta", requires_grad=True) - for lens in lengths - ] - permutes = [ - [0, 0, 0, 0, 3, 4], # f1, jump to 4, as a start - [1, 0, 0, 3, 5, 0], # f3 - [0, 1, 3, 0, 4, 0], # f2 - [1, 2, 5, 0, 6, 0], # f4 - [0, 2, 0, 6, 3, -6], # f1 jump to 6, as in a jump sequence - [2, 2, 0, 9, 8, 0], # f6 - [0, 3, 0, 0, 3, -8], # f1 jump stop, as out of boundary - [1, 3, 11, 3, 7, 0], # f5 - ] - refs = [[] for _ in groups] - for p in permutes: - in_idx, out_idx, in_start, _, length, _ = p - refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) - refs = [torch.cat(ref, dim=1) for ref in refs] - outputs = torch.ops.fbgemm.regroup_keyed_tensor( - values, - keys, - lengths, - groups, - ) - for out, ref in zip(outputs, refs): - self.assertEqual(out.shape, ref.shape) - - # pyre-ignore[56] - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "CUDA is not available", - ) - def test_keyed_tensor_regroup_gpu_forward(self) -> None: - batch_size = 5 - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[3, 4], [5, 6, 7], [8]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - values = [ - torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True) - for lens in lengths 
- ] - permutes = [ - [0, 0, 0, 0, 3, 4], # f1, jump to 4, as a start - [1, 0, 0, 3, 5, 0], # f3 - [0, 1, 3, 0, 4, 0], # f2 - [1, 2, 5, 0, 6, 0], # f4 - [0, 2, 0, 6, 3, -6], # f1 jump to 6, as in a jump sequence - [2, 2, 0, 9, 8, 0], # f6 - [0, 3, 0, 0, 3, -8], # f1 jump stop, as out of boundary - [1, 3, 11, 3, 7, 0], # f5 - ] - refs = [[] for _ in groups] - for p in permutes: - in_idx, out_idx, in_start, _, length, _ = p - refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) - refs = [torch.cat(ref, dim=1) for ref in refs] - outputs = torch.ops.fbgemm.regroup_keyed_tensor( - values, - keys, - lengths, - groups, - ) - for out, ref in zip(outputs, refs): - self.assertTrue(torch.allclose(out, ref)) - - def test_keyed_tensor_regroup_cpu_backward(self) -> None: - batch_size = 5 - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[3, 4], [5, 6, 7], [8]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - values = [ - torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True) - for lens in lengths - ] - ref_values = [v.detach() for v in values] - for v in ref_values: - v.requires_grad = True - permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( - values[0], keys, lengths, groups - ) - refs = [[] for _ in groups] - for i in range(permutes.size(0)): - in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() - refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) - refs = [torch.cat(ref, dim=1) for ref in refs] - outputs = torch.ops.fbgemm.regroup_keyed_tensor( - values, - keys, - lengths, - groups, - ) - for out, ref in zip(outputs, refs): - self.assertTrue(torch.allclose(out, ref)) - - ref_loss, loss = refs[0].sum(), outputs[0].sum() - for i in range(1, len(refs)): - ref_loss += (i + 1.1) * refs[i].sum() - loss += (i + 1.1) * outputs[i].sum() - ref_loss.backward() - loss.backward() - for val, ref in zip(values, ref_values): - val_grad, ref_grad = val.grad, ref.grad - 
assert isinstance(val_grad, torch.Tensor) - self.assertTrue(torch.allclose(val_grad, ref_grad)) - - # pyre-ignore[56] - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "CUDA is not available", - ) - def test_keyed_tensor_regroup_gpu_backward(self) -> None: - batch_size = 5 - keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] - lengths = [[3, 4], [5, 6, 7], [8]] - groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] - values = [ - torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True) - for lens in lengths - ] - ref_values = [v.detach() for v in values] - for v in ref_values: - v.requires_grad = True - permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( - values[0], keys, lengths, groups - ) - refs = [[] for _ in groups] - for i in range(permutes.size(0)): - in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() - refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) - refs = [torch.cat(ref, dim=1) for ref in refs] - outputs = torch.ops.fbgemm.regroup_keyed_tensor( - values, - keys, - lengths, - groups, - ) - for out, ref in zip(outputs, refs): - self.assertTrue(torch.allclose(out, ref)) - - ref_loss, loss = refs[0].sum(), outputs[0].sum() - for i in range(1, len(refs)): - ref_loss += (i + 1.1) * refs[i].sum() - loss += (i + 1.1) * outputs[i].sum() - ref_loss.backward() - loss.backward() - for val, ref in zip(values, ref_values): - val_grad, ref_grad = val.grad, ref.grad - assert isinstance(val_grad, torch.Tensor) - self.assertTrue(torch.allclose(val_grad, ref_grad)) - def test_permute_duplicates(self) -> None: values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0]) @@ -2744,6 +2250,75 @@ def test_regroup_multiple_kt(self) -> None: ) ) + @repeat_test( + regroup_func=[ + KeyedTensor.regroup, + regroup_kts, + permute_multi_embedding, + ], + device_str=["cpu", "cuda", "meta"], + ) + def test_regroup_kts( + self, regroup_func: 
Callable[..., List[torch.Tensor]], device_str: str + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + kts = build_kts( + dense_features=20, + sparse_features=20, + dim_dense=64, + dim_sparse=128, + batch_size=128, + device=device, + run_backward=False, + ) + groups = build_groups(kts=kts, num_groups=2) + refs = _regroup_keyed_tensors(kts, groups) + outputs = regroup_func(kts, groups) + for ref, output in zip(refs, outputs): + self.assertEqual(ref.device, output.device) + if device_str == "meta": + self.assertEqual(ref.shape, output.shape) + else: + torch.testing.assert_close(ref, output) + + @repeat_test( + regroup_func=[ + KeyedTensor.regroup, + regroup_kts, + permute_multi_embedding, + ], + device_str=["cpu", "cuda", "meta"], + ) + def test_regroup_kts_inference( + self, regroup_func: Callable[..., List[torch.Tensor]], device_str: str + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + with torch.inference_mode(): + kts = build_kts( + dense_features=20, + sparse_features=20, + dim_dense=64, + dim_sparse=128, + batch_size=128, + device=device, + run_backward=False, + ) + groups = build_groups(kts=kts, num_groups=2) + refs = _regroup_keyed_tensors(kts, groups) + outputs = regroup_func(kts, groups) + for ref, output in zip(refs, outputs): + self.assertEqual(ref.device, output.device) + if device_str == "meta": + self.assertEqual(ref.shape, output.shape) + else: + torch.testing.assert_close(ref, output) + def test_regroup_backward_skips_and_duplicates(self) -> None: kts = build_kts( dense_features=20, @@ -3000,6 +2575,298 @@ def test_pytree(self) -> None: self.assertListEqual(unflattened._length_per_key, kt._length_per_key) +class TestKeyedTensorRegroupOp(unittest.TestCase): + @repeat_test(device_str=["cpu", "meta", "cuda"]) + def test_kt_regroup_arguments(self, device_str: str) -> None: + if device_str == "cuda" 
and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + torch.empty(0, device=device), keys, lengths, groups + ) + ref_permutes = [ + [0, 0, 0, 0, 3, 4], # f1, jump to 4, as a start + [1, 0, 0, 3, 5, 0], # f3 + [0, 1, 3, 0, 4, 0], # f2 + [1, 2, 5, 0, 6, 0], # f4 + [0, 2, 0, 6, 3, -6], # f1 jump to 6, as in a jump sequence + [2, 2, 0, 9, 8, 0], # f6 + [0, 3, 0, 0, 3, -8], # f1 jump stop, as out of boundary + [1, 3, 11, 3, 7, 0], # f5 + ] + if device_str == "meta": + self.assertEqual(permutes.shape, (len(ref_permutes), len(ref_permutes[0]))) + self.assertEqual(in_shapes.shape, (3,)) + self.assertEqual(out_shapes.shape, (4,)) + else: + self.assertTrue( + torch.equal( + permutes, + torch.tensor(ref_permutes, dtype=torch.int32, device=device), + ) + ) + self.assertEqual(in_shapes.tolist(), [7, 18, 8]) + self.assertEqual(out_shapes.tolist(), [8, 4, 17, 10]) + self.assertEqual(out_lengths, [8, 4, 17, 10]) + + @repeat_test(device_str=["cpu", "meta", "cuda"], batch_size=[16, 128, 1024]) + def test_multi_permute_forward(self, device_str: str, batch_size: int) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + with torch.inference_mode(): + values = [torch.randn(batch_size, sum(L), device=device) for L in lengths] + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + values[0], keys, lengths, groups + ) + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_shapes, out_shapes, out_lengths + ) + + if device_str == "meta": + for out, ref in 
zip(outputs, out_lengths): + self.assertEqual(out.shape, (batch_size, ref)) + else: + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out, in_start, _, length, _ = permutes[i].tolist() + refs[out].append(values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + for out, ref in zip(outputs, refs): + torch.testing.assert_close(out, ref) + + @repeat_test( + ["cpu", 32, [[3, 4], [5, 6, 7], [8]]], + ["cuda", 128, [[96, 256], [512, 128, 768], [1024]]], + ) + def test_multi_permute_backward( + self, device_str: str, batch_size: int, lengths: List[List[int]] + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device=device, requires_grad=True) + for lens in lengths + ] + ref_values = [v.detach() for v in values] + for v in ref_values: + v.requires_grad = True + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + values[0], keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() + refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_shapes, out_shapes, out_lengths + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + ref_loss, loss = refs[0].sum(), outputs[0].sum() + for i in range(1, len(refs)): + ref_loss += (i + 1.1) * refs[i].sum() + loss += (i + 1.1) * outputs[i].sum() + ref_loss.backward() + loss.backward() + for val, ref in zip(values, ref_values): + val_grad, ref_grad = val.grad, ref.grad + assert isinstance(val_grad, torch.Tensor) + 
self.assertTrue(torch.allclose(val_grad, ref_grad)) + + @repeat_test(device_str=["cpu", "cuda"], batch_size=[16, 1024]) + def test_multi_permute_noncontiguous( + self, device_str: str, batch_size: int + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(sum(lens), batch_size, device=device, requires_grad=True) + for lens in lengths + ] + non_contiguous = [v.t() for v in values] + for value in non_contiguous: + self.assertFalse(value.is_contiguous()) + ref_values = [v.detach() for v in values] + for v in ref_values: + v.requires_grad = True + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + non_contiguous[0], keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() + refs[out_idx].append(ref_values[in_idx][in_start : (in_start + length), :]) + refs = [torch.cat(ref).t() for ref in refs] + outputs = torch.ops.fbgemm.permute_multi_embedding( + non_contiguous, permutes, in_shapes, out_shapes, out_lengths + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + ref_loss, loss = refs[0].sum(), outputs[0].sum() + for i in range(1, len(refs)): + ref_loss += (i + 1.1) * refs[i].sum() + loss += (i + 1.1) * outputs[i].sum() + ref_loss.backward() + loss.backward() + for val, ref in zip(values, ref_values): + val_grad, ref_grad = val.grad, ref.grad + assert isinstance(val_grad, torch.Tensor) + self.assertTrue(torch.allclose(val_grad, ref_grad)) + + @repeat_test(device_str=["cpu", "cuda", "meta"], batch_size=[16, 1024]) + def test_kt_regroup_arguments_op(self, device_str: str, batch_size: int) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + 
else: + device = torch.device(device_str) + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + device = torch.device(device) + embs = [torch.randn(batch_size, sum(L), device=device) for L in lengths] + permutes, in_shapes, out_shapes, out_lengths = ( + torch.ops.fbgemm.kt_regroup_arguments( + embs[0], + keys, + lengths, + groups, + ) + ) + ref_permutes = [ + [0, 0, 0, 0, 3, 4], # f1, jump to 4, as a start + [1, 0, 0, 3, 5, 0], # f3 + [0, 1, 3, 0, 4, 0], # f2 + [1, 2, 5, 0, 6, 0], # f4 + [0, 2, 0, 6, 3, -6], # f1 jump to 6, as in a jump sequence + [2, 2, 0, 9, 8, 0], # f6 + [0, 3, 0, 0, 3, -8], # f1 jump stop, as out of boundary + [1, 3, 11, 3, 7, 0], # f5 + ] + if device_str == "meta": + self.assertEqual(permutes.shape, (len(ref_permutes), len(ref_permutes[0]))) + self.assertEqual(in_shapes.shape, (3,)) + self.assertEqual(out_shapes.shape, (4,)) + else: + self.assertTrue( + torch.equal( + permutes, + torch.tensor(ref_permutes, dtype=torch.int32, device=device), + ) + ) + self.assertEqual(in_shapes.tolist(), [7, 18, 8]) + self.assertEqual(out_shapes.tolist(), [8, 4, 17, 10]) + self.assertEqual(out_lengths, [8, 4, 17, 10]) + + @repeat_test(device_str=["cpu", "cuda", "meta"], batch_size=[16, 1024]) + def test_keyed_tensor_regroup_forward( + self, device_str: str, batch_size: int + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + permutes = [ + [0, 0, 0, 0, 3, 4], # f1, jump to 4, as a start + [1, 0, 0, 3, 5, 0], # f3 + [0, 1, 3, 0, 4, 0], # f2 + [1, 2, 5, 0, 6, 0], # f4 + [0, 2, 0, 6, 3, -6], # f1 jump to 6, as in a jump sequence + [2, 2, 0, 9, 8, 0], # f6 + [0, 3, 0, 0, 3, -8], # f1 jump stop, as out of boundary + [1, 3, 11, 3, 7, 
0], # f5 + ] + with torch.inference_mode(): + values = [ + torch.randn(batch_size, sum(lens), device=device) for lens in lengths + ] + refs = [[] for _ in groups] + for p in permutes: + in_idx, out_idx, in_start, _, length, _ = p + refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.regroup_keyed_tensor( + values, + keys, + lengths, + groups, + ) + for out, ref in zip(outputs, refs): + if device_str == "meta": + self.assertEqual(out.shape, ref.shape) + else: + torch.testing.assert_close(out, ref) + + @repeat_test(device_str=["cpu", "cuda"], batch_size=[16, 1024]) + def test_keyed_tensor_regroup_backward( + self, device_str: str, batch_size: int + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device=device, requires_grad=True) + for lens in lengths + ] + ref_values = [v.detach() for v in values] + for v in ref_values: + v.requires_grad = True + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + values[0], keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist() + refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.regroup_keyed_tensor( + values, + keys, + lengths, + groups, + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + ref_loss, loss = refs[0].sum(), outputs[0].sum() + for i in range(1, len(refs)): + ref_loss += (i + 1.1) * refs[i].sum() + loss += (i + 1.1) * outputs[i].sum() + ref_loss.backward() + loss.backward() + for val, 
ref in zip(values, ref_values): + val_grad, ref_grad = val.grad, ref.grad + assert isinstance(val_grad, torch.Tensor) + self.assertTrue(torch.allclose(val_grad, ref_grad)) + + class TestComputeKJTToJTDict(unittest.TestCase): def test_key_lookup(self) -> None: m = ComputeKJTToJTDict() diff --git a/torchrec/sparse/tests/utils.py b/torchrec/sparse/tests/utils.py index d8ae4cad5..5ea4bde1f 100644 --- a/torchrec/sparse/tests/utils.py +++ b/torchrec/sparse/tests/utils.py @@ -7,8 +7,10 @@ # pyre-strict +import functools import random -from typing import List +import unittest +from typing import Any, Callable, List, Sequence import torch from torchrec.sparse.jagged_tensor import KeyedTensor @@ -61,3 +63,26 @@ def build_groups( for group in groups: group.append(random.choice(all_keys)) return groups + + +def repeat_test( + *args: List[Any], **kwargs: Sequence[Any] +) -> Callable[..., Callable[..., None]]: + def decorate(f: Callable[..., None]) -> Callable[..., None]: + @functools.wraps(f) + def decorator(self: unittest.TestCase) -> None: + queue = [(arg, {}) for arg in args] if args else [((), {})] + for k, values in kwargs.items(): + new_queue = [] + for a, d in queue: + for v in values: + new_d = d | {k: v} + new_queue.append((a, new_d)) + queue = new_queue + for a, d in queue: + print(f"running {f.__name__} {a} {d}") + f(self, *a, **d) + + return decorator + + return decorate