Count shard state in HBM usage (#2380)
Summary:
Pull Request resolved: #2380

X-link: facebookresearch/FBGEMM#203

X-link: pytorch/FBGEMM#3114

This PR improves the sparse HBM cost estimate by accounting for the size of the auxiliary state needed to maintain the UVM cache. As noted in the comment in split_table_batched_embeddings_ops_training, the significant space is currently `4 * hash_size + 8 * cache_slot_size + 8 * cache_slot_size`. This becomes nontrivial for tables with many rows but few dimensions.
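
For illustration, a minimal sketch of the per-shard estimate this change introduces (the helper name and example numbers below are hypothetical and not part of the diff; `clf` is the cache load factor):

```python
import math

def estimate_cache_aux_state_bytes(num_rows: int, clf: float) -> int:
    # ~4 bytes of hash-table state per row of the shard, plus ~16 bytes
    # (8 + 8) per cache slot, where the slot count is clf * num_rows
    return math.ceil(num_rows * (4 + clf * 16))

# e.g. a 1M-row shard cached with clf=0.2 reserves about 7.2 MB of HBM
# for cache aux state, on top of the cached rows themselves
print(estimate_cache_aux_state_bytes(1_000_000, 0.2))  # 7200000
```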

Impact:
- Non-UVM-offloaded job: no-op.
- UVM-offloaded job: more balanced memory usage from the more precise estimate. However, for existing UVM jobs that use the scale-up proposer with a fixed reservation percentage, the proposer may scale the cache up less aggressively, which can hurt performance. In that case, tune the reservation percentage to leave more slack.

Reviewed By: sarckk

Differential Revision: D61576911

fbshipit-source-id: 6b501dc63cbe86c5274661b1d985af6a7a0a87c6
levythu authored and facebook-github-bot committed Sep 12, 2024
1 parent a742064 commit 760758f
Showing 4 changed files with 105 additions and 42 deletions.
21 changes: 19 additions & 2 deletions torchrec/distributed/planner/shard_estimators.py
@@ -1087,6 +1087,7 @@ def calculate_shard_storages(
sharding_type=sharding_type,
optimizer_class=optimizer_class,
is_inference=is_inference,
clf=caching_ratio if table_cached else None,
)
ddr_specific_sizes: List[int] = _calculate_storage_specific_sizes(
storage=ddr_storage,
@@ -1395,6 +1396,7 @@ def _calculate_storage_specific_sizes(
sharding_type: str,
optimizer_class: Optional[Type[torch.optim.Optimizer]] = None,
is_inference: bool = False,
clf: Optional[float] = None,
) -> List[int]:
tensor_sizes: List[int] = [
(
@@ -1410,9 +1412,24 @@
math.ceil(tensor_size * optimizer_multipler) for tensor_size in tensor_sizes
]

# If a table has turned on UVM caching (meaning clf is not None), there'll be
# 4x of table hash size and 16x of cache slot size HBM storage cost dedicated to
# cache aux state (note that this is not the cache content itself)
cache_aux_state_sizes: List[int] = (
[0] * len(shard_sizes)
if clf is None
else [math.ceil(size[0] * (4 + clf * 16)) for size in shard_sizes]
)

return [
tensor_size + optimizer_size if not is_inference else tensor_size
for tensor_size, optimizer_size in zip(tensor_sizes, optimizer_sizes)
(
cache_state_size + tensor_size + optimizer_size
if not is_inference
else tensor_size
)
for cache_state_size, tensor_size, optimizer_size in zip(
cache_aux_state_sizes, tensor_sizes, optimizer_sizes
)
]


72 changes: 39 additions & 33 deletions torchrec/distributed/planner/tests/test_enumerators.py
@@ -7,6 +7,7 @@

# pyre-strict

import math
import unittest
from typing import cast, List
from unittest.mock import MagicMock, patch
@@ -37,7 +38,6 @@
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig


EXPECTED_RW_SHARD_SIZES = [
[[13, 20], [13, 20], [13, 20], [13, 20], [13, 20], [13, 20], [13, 20], [9, 20]],
[[14, 40], [14, 40], [14, 40], [14, 40], [14, 40], [14, 40], [14, 40], [12, 40]],
@@ -52,6 +52,12 @@
[[0, 0], [17, 0], [34, 0], [51, 0], [68, 0], [85, 0], [102, 0], [119, 0]],
]


def get_expected_cache_aux_size(rows: int) -> int:
# 0.2 is the hardcoded cache load factor assumed in this test
return math.ceil(rows * (4 + 0.2 * 16))


EXPECTED_RW_SHARD_STORAGE = [
[
Storage(hbm=166928, ddr=0),
@@ -98,44 +104,44 @@

EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [
[
Storage(hbm=166096, ddr=1040),
Storage(hbm=166096, ddr=1040),
Storage(hbm=166096, ddr=1040),
Storage(hbm=166096, ddr=1040),
Storage(hbm=166096, ddr=1040),
Storage(hbm=166096, ddr=1040),
Storage(hbm=166096, ddr=1040),
Storage(hbm=166032, ddr=720),
Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040),
Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040),
Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040),
Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040),
Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040),
Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040),
Storage(hbm=166096 + get_expected_cache_aux_size(13), ddr=1040),
Storage(hbm=166032 + get_expected_cache_aux_size(9), ddr=720),
],
[
Storage(hbm=1001920, ddr=2240),
Storage(hbm=1001920, ddr=2240),
Storage(hbm=1001920, ddr=2240),
Storage(hbm=1001920, ddr=2240),
Storage(hbm=1001920, ddr=2240),
Storage(hbm=1001920, ddr=2240),
Storage(hbm=1001920, ddr=2240),
Storage(hbm=1001856, ddr=1920),
Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240),
Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240),
Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240),
Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240),
Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240),
Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240),
Storage(hbm=1001920 + get_expected_cache_aux_size(14), ddr=2240),
Storage(hbm=1001856 + get_expected_cache_aux_size(12), ddr=1920),
],
[
Storage(hbm=1004240, ddr=3600),
Storage(hbm=1004240, ddr=3600),
Storage(hbm=1004240, ddr=3600),
Storage(hbm=1004240, ddr=3600),
Storage(hbm=1004240, ddr=3600),
Storage(hbm=1004240, ddr=3600),
Storage(hbm=1004240, ddr=3600),
Storage(hbm=1004240, ddr=3600),
Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600),
Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600),
Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600),
Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600),
Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600),
Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600),
Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600),
Storage(hbm=1004240 + get_expected_cache_aux_size(15), ddr=3600),
],
[
Storage(hbm=2649152, ddr=5440),
Storage(hbm=2649152, ddr=5440),
Storage(hbm=2649152, ddr=5440),
Storage(hbm=2649152, ddr=5440),
Storage(hbm=2649152, ddr=5440),
Storage(hbm=2649152, ddr=5440),
Storage(hbm=2649152, ddr=5440),
Storage(hbm=2648768, ddr=3520),
Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440),
Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440),
Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440),
Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440),
Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440),
Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440),
Storage(hbm=2649152 + get_expected_cache_aux_size(17), ddr=5440),
Storage(hbm=2648768 + get_expected_cache_aux_size(11), ddr=3520),
],
]

31 changes: 28 additions & 3 deletions torchrec/distributed/planner/tests/test_proposers.py
@@ -8,7 +8,7 @@
# pyre-strict

import unittest
from typing import cast, List, Optional
from typing import cast, List, Optional, Type
from unittest.mock import MagicMock

import torch
@@ -25,6 +25,7 @@
UniformProposer,
)
from torchrec.distributed.planner.shard_estimators import (
_calculate_storage_specific_sizes,
EmbeddingPerfEstimator,
EmbeddingStorageEstimator,
)
@@ -86,6 +87,22 @@ def cacheability(self) -> float:
return self._cacheability


# Mocking _calculate_storage_specific_sizes to skip cache aux state accounting for
# simpler testing
def mock_calculate_storage_specific_sizes(
storage: int,
shape: torch.Size,
shard_sizes: List[List[int]],
sharding_type: str,
optimizer_class: Optional[Type[torch.optim.Optimizer]] = None,
is_inference: bool = False,
clf: Optional[float] = None,
) -> List[int]:
return _calculate_storage_specific_sizes(
storage, shape, shard_sizes, sharding_type, optimizer_class, is_inference, None
)


class TestProposers(unittest.TestCase):
def setUp(self) -> None:
topology = Topology(world_size=2, compute_device="cuda")
@@ -466,7 +483,11 @@ def test_allocate_budget(self) -> None:
)
self.assertEqual(increase, budget)

def test_scaleup(self) -> None:
@unittest.mock.patch(
"torchrec.distributed.planner.shard_estimators._calculate_storage_specific_sizes",
side_effect=mock_calculate_storage_specific_sizes,
)
def test_scaleup(self, _) -> None:
tables = [
EmbeddingBagConfig(
num_embeddings=2_000_000,
@@ -697,7 +718,11 @@ def mock_storage_estimator_func(so: List[ShardingOption]) -> None:
["fused", "fused", "fused_uvm_caching"],
)

def test_budget_shrink(self) -> None:
@unittest.mock.patch(
"torchrec.distributed.planner.shard_estimators._calculate_storage_specific_sizes",
side_effect=mock_calculate_storage_specific_sizes,
)
def test_budget_shrink(self, _) -> None:
tables = [
EmbeddingBagConfig(
num_embeddings=2_000_000,
23 changes: 19 additions & 4 deletions torchrec/distributed/planner/tests/test_shard_estimators.py
@@ -7,6 +7,7 @@

# pyre-strict

import math
import unittest
from typing import cast, Dict, List, Tuple

@@ -623,35 +624,49 @@ def calculate_storage_specific_size_data_provider():
"sharding_type": ShardingType.TABLE_ROW_WISE,
"optimizer_class": torch.optim.SGD,
"expected_storage": [50, 50],
"clf": None,
},
{
"sharding_type": ShardingType.COLUMN_WISE,
"optimizer_class": torch.optim.Adam,
"expected_storage": [150, 150],
"expected_storage": [
150 + math.ceil(5 * (4 + 0.5 * 16)),
150 + math.ceil(5 * (4 + 0.5 * 16)),
],
"clf": 0.5,
},
{
"sharding_type": ShardingType.TABLE_ROW_WISE,
"optimizer_class": None,
"expected_storage": [50, 50],
"expected_storage": [
50 + math.ceil(5 * (4 + 0.0 * 16)),
50 + math.ceil(5 * (4 + 0.0 * 16)),
],
"clf": 0.0,
},
{
"sharding_type": ShardingType.DATA_PARALLEL,
"optimizer_class": trec_optim.RowWiseAdagrad,
"expected_storage": [134, 134],
"expected_storage": [
134 + math.ceil(5 * (4 + 1.0 * 16)),
134 + math.ceil(5 * (4 + 1.0 * 16)),
],
"clf": 1.0,
},
)


class TestEmbeddingStorageEstimator(unittest.TestCase):
def test_calculate_storage_specific_sizes(self) -> None:
for inputs in calculate_storage_specific_size_data_provider():
sharding_type, optimizer_class, expected_storage = inputs.values()
sharding_type, optimizer_class, expected_storage, clf = inputs.values()
estimates = _calculate_storage_specific_sizes(
storage=100,
shape=torch.Size((10, 5, 3)),
shard_sizes=[[5, 5, 3], [5, 5, 3]],
sharding_type=sharding_type.value,
optimizer_class=optimizer_class,
clf=clf,
)

self.assertEqual(estimates, expected_storage)
