implementation of fbgemm op - regroup_keyed_tensor (pytorch#2128)
Summary:
Pull Request resolved: pytorch#2128

# context
* current production uses `fbgemm.permute_pooled_embs_auto_grad` for `KT.regroup`.
* it has several downsides:
  a) it needs to perform a `torch.cat` operation, which costs memory and time
  b) it only supports groupings without duplicates; otherwise it falls back to a slower PyTorch-native implementation
* the new implementation uses `fbgemm.permute_multi_embedding` for the same function (see the sketch after this list):
  a) it doesn't need `torch.cat`, so it saves memory and time
  b) it supports duplicates in the grouping without sacrificing performance
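
A minimal sketch of the use case, assuming the `KeyedTensor` API from `torchrec.sparse.jagged_tensor` (feature names and sizes are illustrative):
```python
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

batch_size = 4
kt1 = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[3, 4],  # embedding dim per key
    values=torch.randn(batch_size, 3 + 4),
)
kt2 = KeyedTensor(
    keys=["f3"],
    length_per_key=[5],
    values=torch.randn(batch_size, 5),
)

# "f1" appears in both groups (a duplicate): the old fbgemm path would fall
# back to the slower native implementation here, while the new
# permute_multi_embedding-based path handles duplicates directly.
groups = [["f1", "f3"], ["f1", "f2"]]
grouped = KeyedTensor.regroup([kt1, kt2], groups)  # List[torch.Tensor]
# grouped[0]: (batch_size, 3 + 5), grouped[1]: (batch_size, 3 + 4)
```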

# benchmark results
* stats sheet
|item|baseline|new function|delta perf (%)|notes|
|---|---|---|---|---|
|**runtime**|5.2 ms|2.7 ms|48%|w/o dups|
|**memory**|1.5 K|1.0 K|33%|w/o dups|
|**runtime**|12.3 ms|2.7 ms|78%|w/ dups|
|**memory**|1.0 K|1.0 K|0%|w/ dups|
* log output (GPU)
```
  _regroup_keyed_tenors               | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):  13.1 ms | Memory (P90): 1011.0
  permute_multi_embs                  | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):   2.7 ms | Memory (P90): 1011.0
  KeyedTensor_regroup                 | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):   5.2 ms | Memory (P90): 1517.0
  KTRegroupAsDict                     | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):   4.9 ms | Memory (P90): 1517.0
  _regroup_keyed_tenors_dup           | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):  12.3 ms | Memory (P90): 1011.0
  permute_multi_embs_dup              | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):   2.7 ms | Memory (P90): 1011.0
  KeyedTensor_regroup_dup             | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):  12.0 ms | Memory (P90): 1011.0
  KTRegroupAsDict_dup                 | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):  11.4 ms | Memory (P90): 1011.0
```
* CPU results are interesting: the PyTorch-native fallback is actually the fastest on CPU, so the new op's gains show up on GPU
```
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 1020     | device: cpu      | Runtime (P90):   0.4 ms | Memory (P90):   0.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 1020     | device: cpu      | Runtime (P90):   0.7 ms | Memory (P90):   0.0
  [prod] KTRegroupAsDict              | B: 1024     | F: 1020     | device: cpu      | Runtime (P90):   0.6 ms | Memory (P90):   0.0
```

Differential Revision: D58649553
Huanyu He authored and facebook-github-bot committed Jul 10, 2024
1 parent 04c8076 commit 1dc3dde
Showing 2 changed files with 236 additions and 2 deletions.
2 changes: 1 addition & 1 deletion torchrec/sparse/jagged_tensor.py
```diff
@@ -175,7 +175,7 @@ def permute_multi_embedding(
     keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
 ) -> List[torch.Tensor]:
     keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
-    permutes, in_shape, out_shape, out_lengths = _kt_regroup_permutes(
+    permutes, in_shape, out_shape, out_lengths = torch.ops.fbgemm.kt_regroup_permutes(
         values[0], keys, lengths, groups
     )
     permuted_values = torch.ops.fbgemm.permute_multi_embedding(
```
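
For reference, each row of the `permutes` tensor produced by `kt_regroup_permutes` encodes `[input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]`, as decoded by the tests below. A minimal sketch of that reference decoding (the helper name is hypothetical):
```python
from typing import List

import torch

def reference_regroup(
    values: List[torch.Tensor], permutes: torch.Tensor, num_groups: int
) -> List[torch.Tensor]:
    # Collect the column slices each output group is assembled from.
    chunks: List[List[torch.Tensor]] = [[] for _ in range(num_groups)]
    for in_idx, out_idx, in_start, _out_start, length, _jump in permutes.tolist():
        # Copy `length` embedding columns from input tensor `in_idx`;
        # the `jump` field only chains duplicated keys and is not needed here.
        chunks[out_idx].append(values[in_idx][:, in_start : in_start + length])
    return [torch.cat(c, dim=1) for c in chunks]
```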
236 changes: 235 additions & 1 deletion torchrec/sparse/tests/test_jagged_tensor.py
```diff
@@ -1404,7 +1404,7 @@ def test_kt_regroup_permutes(self) -> None:
         groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
         for device in ["cpu", "meta", "cuda"]:
             if device == "cuda" and not torch.cuda.is_available():
-                continue
+                continue  # skip cuda tests if cuda is not available
             device = torch.device(device)
             permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
                 torch.empty(0, device=device), keys, lengths, groups
@@ -1584,6 +1584,240 @@ def test_multi_permute_backward_gpu(self) -> None:
             assert isinstance(val_grad, torch.Tensor)
             self.assertTrue(torch.allclose(val_grad, ref_grad))
+
+    def test_kt_regroup_permutes_op(self) -> None:
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        batch_size = 5
+        for device in ["cpu", "meta", "cuda"]:
+            if device == "cuda" and not torch.cuda.is_available():
+                continue  # skip cuda tests if cuda is not available
+            device = torch.device(device)
+            embs = [torch.randn(batch_size, sum(l), device=device) for l in lengths]
+            permutes, in_shapes, out_shapes, out_lengths = (
+                torch.ops.fbgemm.kt_regroup_permutes(
+                    embs[0],
+                    keys,
+                    lengths,
+                    groups,
+                )
+            )
+            ref_permutes = [
+                [0, 0, 0, 0, 3, 4],  # f1, jump to 4, as a start
+                [1, 0, 0, 3, 5, 0],  # f3
+                [0, 1, 3, 0, 4, 0],  # f2
+                [1, 2, 5, 0, 6, 0],  # f4
+                [0, 2, 0, 6, 3, -6],  # f1 jump to 6, as in a jump sequence
+                [2, 2, 0, 9, 8, 0],  # f6
+                [0, 3, 0, 0, 3, -8],  # f1 jump stop, as out of boundary
+                [1, 3, 11, 3, 7, 0],  # f5
+            ]
+            if device.type == "meta":
+                self.assertEqual(
+                    permutes.shape, (len(ref_permutes), len(ref_permutes[0]))
+                )
+                self.assertEqual(in_shapes.shape, (3,))
+                self.assertEqual(out_shapes.shape, (4,))
+            else:
+                self.assertTrue(
+                    torch.equal(
+                        permutes,
+                        torch.tensor(ref_permutes, dtype=torch.int32, device=device),
+                    )
+                )
+                self.assertEqual(in_shapes.tolist(), [7, 18, 8])
+                self.assertEqual(out_shapes.tolist(), [8, 4, 17, 10])
+            self.assertEqual(out_lengths, [8, 4, 17, 10])
+
+    def test_keyed_tensor_regroup_cpu_forward(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes = [
+            [0, 0, 0, 0, 3, 4],  # f1, jump to 4, as a start
+            [1, 0, 0, 3, 5, 0],  # f3
+            [0, 1, 3, 0, 4, 0],  # f2
+            [1, 2, 5, 0, 6, 0],  # f4
+            [0, 2, 0, 6, 3, -6],  # f1 jump to 6, as in a jump sequence
+            [2, 2, 0, 9, 8, 0],  # f6
+            [0, 3, 0, 0, 3, -8],  # f1 jump stop, as out of boundary
+            [1, 3, 11, 3, 7, 0],  # f5
+        ]
+        refs = [[] for _ in groups]
+        for p in permutes:
+            in_idx, out_idx, in_start, _, length, _ = p
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.regroup_keyed_tensor(
+            values,
+            keys,
+            lengths,
+            groups,
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+    def test_keyed_tensor_regroup_meta_forward(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="meta", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes = [
+            [0, 0, 0, 0, 3, 4],  # f1, jump to 4, as a start
+            [1, 0, 0, 3, 5, 0],  # f3
+            [0, 1, 3, 0, 4, 0],  # f2
+            [1, 2, 5, 0, 6, 0],  # f4
+            [0, 2, 0, 6, 3, -6],  # f1 jump to 6, as in a jump sequence
+            [2, 2, 0, 9, 8, 0],  # f6
+            [0, 3, 0, 0, 3, -8],  # f1 jump stop, as out of boundary
+            [1, 3, 11, 3, 7, 0],  # f5
+        ]
+        refs = [[] for _ in groups]
+        for p in permutes:
+            in_idx, out_idx, in_start, _, length, _ = p
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.regroup_keyed_tensor(
+            values,
+            keys,
+            lengths,
+            groups,
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertEqual(out.shape, ref.shape)
+
+    # pyre-ignore[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "CUDA is not available",
+    )
+    def test_keyed_tensor_regroup_gpu_forward(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes = [
+            [0, 0, 0, 0, 3, 4],  # f1, jump to 4, as a start
+            [1, 0, 0, 3, 5, 0],  # f3
+            [0, 1, 3, 0, 4, 0],  # f2
+            [1, 2, 5, 0, 6, 0],  # f4
+            [0, 2, 0, 6, 3, -6],  # f1 jump to 6, as in a jump sequence
+            [2, 2, 0, 9, 8, 0],  # f6
+            [0, 3, 0, 0, 3, -8],  # f1 jump stop, as out of boundary
+            [1, 3, 11, 3, 7, 0],  # f5
+        ]
+        refs = [[] for _ in groups]
+        for p in permutes:
+            in_idx, out_idx, in_start, _, length, _ = p
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.regroup_keyed_tensor(
+            values,
+            keys,
+            lengths,
+            groups,
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+    def test_keyed_tensor_regroup_cpu_backward(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(permutes.size(0)):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.regroup_keyed_tensor(
+            values,
+            keys,
+            lengths,
+            groups,
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss, loss = refs[0].sum(), outputs[0].sum()
+        for i in range(1, len(refs)):
+            ref_loss += (i + 1.1) * refs[i].sum()
+            loss += (i + 1.1) * outputs[i].sum()
+        ref_loss.backward()
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            val_grad, ref_grad = val.grad, ref.grad
+            assert isinstance(val_grad, torch.Tensor)
+            self.assertTrue(torch.allclose(val_grad, ref_grad))
+
+    # pyre-ignore[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "CUDA is not available",
+    )
+    def test_keyed_tensor_regroup_gpu_backward(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(permutes.size(0)):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.regroup_keyed_tensor(
+            values,
+            keys,
+            lengths,
+            groups,
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss, loss = refs[0].sum(), outputs[0].sum()
+        for i in range(1, len(refs)):
+            ref_loss += (i + 1.1) * refs[i].sum()
+            loss += (i + 1.1) * outputs[i].sum()
+        ref_loss.backward()
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            val_grad, ref_grad = val.grad, ref.grad
+            assert isinstance(val_grad, torch.Tensor)
+            self.assertTrue(torch.allclose(val_grad, ref_grad))
+
     def test_permute_duplicates(self) -> None:
         values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
         lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])
```
