2024-04-12 nightly release (8417057)
pytorchbot committed Apr 12, 2024
1 parent fe41c85 commit dc2d95b
Showing 12 changed files with 385 additions and 143 deletions.
55 changes: 40 additions & 15 deletions torchrec/distributed/infer_utils.py
@@ -7,26 +7,51 @@
 
 from typing import List, Tuple
 
+import torch
+
 from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
+from torchrec.distributed.quant_embedding import ShardedQuantEmbeddingCollection
 
 from torchrec.distributed.quant_embeddingbag import ShardedQuantEmbeddingBagCollection
 
 
-def get_tbe_specs_from_sqebc(
-    sqebc: ShardedQuantEmbeddingBagCollection,
+def get_tbes_from_sharded_module(
+    module: torch.nn.Module,
+) -> List[IntNBitTableBatchedEmbeddingBagsCodegen]:
+    assert type(module) in [
+        ShardedQuantEmbeddingBagCollection,
+        ShardedQuantEmbeddingCollection,
+    ], "Only support ShardedQuantEmbeddingBagCollection and ShardedQuantEmbeddingCollection for get TBEs"
+    tbes = []
+    for lookup in module._lookups:
+        for lookup_per_rank in lookup._embedding_lookups_per_rank:
+            for emb_module in lookup_per_rank._emb_modules:
+                tbes.append(emb_module._emb_module)
+    return tbes
+
+
+def get_tbe_specs_from_sharded_module(
+    module: torch.nn.Module,
 ) -> List[
     Tuple[str, int, int, str, str]
 ]:  # # tuple of (feature_names, rows, dims, str(SparseType), str(EmbeddingLocation/placement))
+    assert type(module) in [
+        ShardedQuantEmbeddingBagCollection,
+        ShardedQuantEmbeddingCollection,
+    ], "Only support ShardedQuantEmbeddingBagCollection and ShardedQuantEmbeddingCollection for get TBE specs"
     tbe_specs = []
-    for lookup in sqebc._lookups:
-        for lookup_per_rank in lookup._embedding_lookups_per_rank:
-            for emb_module in lookup_per_rank._emb_modules:
-                for spec in emb_module._emb_module.embedding_specs:
-                    tbe_specs.append(
-                        (
-                            spec[0],
-                            spec[1],
-                            spec[2],
-                            str(spec[3]),
-                            str(spec[4]),
-                        )
-                    )
+    tbes = get_tbes_from_sharded_module(module)
+    for tbe in tbes:
+        for spec in tbe.embedding_specs:
+            tbe_specs.append(
+                (
+                    spec[0],
+                    spec[1],
+                    spec[2],
+                    str(spec[3]),
+                    str(spec[4]),
+                )
+            )
     return tbe_specs
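
The two helpers above compose into a small inspection utility. A minimal sketch, assuming `sharded_module` is an already-sharded ShardedQuantEmbeddingBagCollection or ShardedQuantEmbeddingCollection produced elsewhere; the `summarize_tbes` name is illustrative and not part of this diff.

import torch

from torchrec.distributed.infer_utils import (
    get_tbe_specs_from_sharded_module,
    get_tbes_from_sharded_module,
)


def summarize_tbes(sharded_module: torch.nn.Module) -> None:
    # Raw IntNBitTableBatchedEmbeddingBagsCodegen modules, one per lookup per rank.
    tbes = get_tbes_from_sharded_module(sharded_module)
    print(f"found {len(tbes)} TBE module(s)")
    # Flattened per-table specs: (feature_names, rows, dims, sparse_type, placement).
    for feature_names, rows, dims, sparse_type, placement in get_tbe_specs_from_sharded_module(
        sharded_module
    ):
        print(f"{feature_names}: rows={rows}, dims={dims}, dtype={sparse_type}, placement={placement}")
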
19 changes: 19 additions & 0 deletions torchrec/distributed/planner/enumerators.py
@@ -66,6 +66,15 @@ def __init__(
         self._batch_size: int = batch_size
         self._constraints = constraints
         self._sharder_map: Dict[str, ModuleSharder[nn.Module]] = {}
+        memory_type = "hbm_cap" if topology.compute_device == "cuda" else "ddr_cap"
+        self._device_memory_sizes: Optional[
+            List[int]
+        ] = (  # only used with custom topology where memory is different within a topology
+            topology._custom_topology_data.get_data(memory_type)
+            if topology._custom_topology_data
+            and topology._custom_topology_data.has_data(memory_type)
+            else None
+        )
 
         if estimator:
             self._estimators: List[ShardEstimator] = (
@@ -130,8 +139,13 @@ def enumerate(
                     bounds_check_mode,
                     feature_names,
                     output_dtype,
+                    device_group,
                 ) = _extract_constraints_for_param(self._constraints, name)
 
+                # skip for other device groups
+                if device_group and device_group != self._compute_device:
+                    continue
+
                 sharding_options_per_table: List[ShardingOption] = []
 
                 for sharding_type in self._filter_sharding_types(
@@ -151,6 +165,7 @@
                        local_world_size=self._local_world_size,
                        sharding_type=sharding_type,
                        col_wise_shard_dim=col_wise_shard_dim,
+                        device_memory_sizes=self._device_memory_sizes,
                    )
                    dependency = None
                    if isinstance(child_module, EmbeddingTower):
@@ -278,6 +293,7 @@ def _extract_constraints_for_param(
     Optional[BoundsCheckMode],
     Optional[List[str]],
     Optional[DataType],
+    Optional[str],
 ]:
     input_lengths = [POOLING_FACTOR]
     col_wise_shard_dim = None
@@ -287,6 +303,7 @@
     bounds_check_mode = None
     feature_names = None
     output_dtype = None
+    device_group = None
 
     if constraints and constraints.get(name):
         input_lengths = constraints[name].pooling_factors
@@ -297,6 +314,7 @@
         bounds_check_mode = constraints[name].bounds_check_mode
         feature_names = constraints[name].feature_names
         output_dtype = constraints[name].output_dtype
+        device_group = constraints[name].device_group
 
     return (
         input_lengths,
@@ -307,6 +325,7 @@
         bounds_check_mode,
         feature_names,
         output_dtype,
+        device_group,
     )
 
 
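
In practice, the new check means a table can be pinned to a device group in its constraints and the enumerator skips it on non-matching topologies. A hedged sketch, assuming ParameterConstraints exposes the device_group field read above; the table names are illustrative.

from torchrec.distributed.planner.types import ParameterConstraints

# Tables whose device_group differs from the topology's compute device are skipped
# during enumeration (see the `continue` added above).
constraints = {
    "cpu_only_table": ParameterConstraints(device_group="cpu"),
    "gpu_table": ParameterConstraints(device_group="cuda"),
}
# An enumerator built for a "cuda" topology would then only produce sharding
# options for "gpu_table".
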
4 changes: 2 additions & 2 deletions torchrec/distributed/planner/tests/test_shard_estimators.py
@@ -30,8 +30,8 @@
 )
 from torchrec.distributed.planner.types import ParameterConstraints, Perf, Topology
 from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
+from torchrec.distributed.test_utils.infer_utils import quantize
 from torchrec.distributed.test_utils.test_model import TestEBCSharder, TestSparseNN
-from torchrec.distributed.tests.test_quant_model_parallel import _quantize
 from torchrec.distributed.tests.test_sequence_model import TestSequenceSparseNN
 from torchrec.distributed.types import (
     CacheParams,
@@ -378,7 +378,7 @@ def test_inference_1_table_perf(self) -> None:
             )
         ]
         model = TestSparseNN(tables=tables, weighted_tables=[])
-        quant_model = _quantize(model, inplace=True)
+        quant_model = quantize(model, inplace=True)
 
         inference_estimator = EmbeddingPerfEstimator(
             topology=self.topology, is_inference=True
33 changes: 33 additions & 0 deletions torchrec/distributed/planner/types.py
@@ -141,6 +141,37 @@ class DeviceHardware:
     perf: Perf
 
 
+class CustomTopologyData:
+    """
+    Custom device data for individual device in a topology.
+    """
+
+    supported_fields = ["ddr_cap", "hbm_cap"]
+
+    def __init__(
+        self,
+        data: Dict[str, List[int]],
+        world_size: int,
+    ) -> None:
+        assert all(
+            key in self.supported_fields for key in data.keys()
+        ), f"{data.keys()} not supported in CustomTopologyData"
+        assert all(
+            len(v) == world_size for v in data.values()
+        ), f"{data.values()} must be positive"
+        self._data = data
+        self._world_size = world_size
+
+    def get_data(self, key: str) -> List[int]:
+        assert (
+            key in self.supported_fields
+        ), f"{key} not supported in CustomTopologyData"
+        return self._data[key]
+
+    def has_data(self, key: str) -> bool:
+        return key in self._data
+
+
 class Topology:
     def __init__(
         self,
@@ -154,6 +185,7 @@ def __init__(
         intra_host_bw: float = INTRA_NODE_BANDWIDTH,
         inter_host_bw: float = CROSS_NODE_BANDWIDTH,
         bwd_compute_multiplier: float = BWD_COMPUTE_MULTIPLIER,
+        custom_topology_data: Optional[CustomTopologyData] = None,
     ) -> None:
         """
         Representation of a network of devices in a cluster.
@@ -191,6 +223,7 @@ def __init__(
         self._intra_host_bw = intra_host_bw
         self._inter_host_bw = inter_host_bw
         self._bwd_compute_multiplier = bwd_compute_multiplier
+        self._custom_topology_data = custom_topology_data
 
     @property
     def compute_device(self) -> str:
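
A minimal sketch of describing a fleet with uneven per-rank HBM via the new CustomTopologyData and handing it to Topology; the world size and capacity numbers are made up for illustration.

from torchrec.distributed.planner.types import CustomTopologyData, Topology

world_size = 4
# Per-rank HBM capacities in bytes; ranks 2 and 3 have half the memory of ranks 0 and 1.
custom = CustomTopologyData(
    data={"hbm_cap": [32 * 1024**3, 32 * 1024**3, 16 * 1024**3, 16 * 1024**3]},
    world_size=world_size,
)
topology = Topology(
    world_size=world_size,
    compute_device="cuda",
    custom_topology_data=custom,
)
# EmbeddingEnumerator reads these per-rank sizes (see enumerators.py above) and
# forwards them to calculate_shard_sizes_and_offsets for uneven row-wise sharding.
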
41 changes: 40 additions & 1 deletion torchrec/distributed/sharding_plan.py
@@ -83,6 +83,7 @@ def calculate_shard_sizes_and_offsets(
     local_world_size: int,
     sharding_type: str,
     col_wise_shard_dim: Optional[int] = None,
+    device_memory_sizes: Optional[List[int]] = None,
 ) -> Tuple[List[List[int]], List[List[int]]]:
     """
     Calculates sizes and offsets for tensor sharded according to provided sharding type.
@@ -103,12 +104,23 @@
 
     (rows, columns) = tensor.shape
 
+    if device_memory_sizes is not None:
+        assert (
+            sharding_type == ShardingType.ROW_WISE.value
+        ), "Currently only support uneven sharding for row_wise sharding"
+
     if sharding_type == ShardingType.DATA_PARALLEL.value:
         return [[rows, columns]] * world_size, [[0, 0]] * world_size
     elif sharding_type == ShardingType.TABLE_WISE.value:
         return [[rows, columns]], [[0, 0]]
     elif sharding_type == ShardingType.ROW_WISE.value:
-        return _calculate_rw_shard_sizes_and_offsets(rows, world_size, columns)
+        return (
+            _calculate_rw_shard_sizes_and_offsets(rows, world_size, columns)
+            if not device_memory_sizes
+            else _calculate_uneven_rw_shard_sizes_and_offsets(
+                rows, world_size, columns, device_memory_sizes
+            )
+        )
     elif sharding_type == ShardingType.TABLE_ROW_WISE.value:
         return _calculate_rw_shard_sizes_and_offsets(rows, local_world_size, columns)
     elif (
@@ -159,6 +171,33 @@ def _calculate_rw_shard_sizes_and_offsets(
     return shard_sizes, shard_offsets
 
 
+def _calculate_uneven_rw_shard_sizes_and_offsets(
+    hash_size: int, num_devices: int, columns: int, device_memory_sizes: List[int]
+) -> Tuple[List[List[int]], List[List[int]]]:
+    assert num_devices == len(device_memory_sizes), "must provide all the memory size"
+    total_size = sum(device_memory_sizes)
+    shard_sizes: List[List[int]] = []
+    last_rank = num_devices - 1
+
+    processed_total_rows = 0
+
+    for rank in range(num_devices):
+        if rank < last_rank:
+            local_row: int = int(hash_size * (device_memory_sizes[rank] / total_size))
+            processed_total_rows += local_row
+        elif rank == last_rank:
+            local_row: int = hash_size - processed_total_rows
+        else:
+            local_row: int = 0
+        shard_sizes.append([local_row, columns])
+    shard_offsets = [[0, 0]]
+
+    for i in range(num_devices - 1):
+        shard_offsets.append([shard_sizes[i][0] + shard_offsets[i][0], 0])
+
+    return shard_sizes, shard_offsets
+
+
 def _find_base_dim(lower_bound: int, dim: int) -> int:
     for i in range(lower_bound, dim):
         if dim % i == 0 and i % 4 == 0:
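
To make the proportional split concrete, here is a worked example against the helper added above (a private function, so subject to change); the numbers are illustrative.

from torchrec.distributed.sharding_plan import (
    _calculate_uneven_rw_shard_sizes_and_offsets,
)

# 100 rows split across 3 devices with relative memory sizes 40:40:20.
sizes, offsets = _calculate_uneven_rw_shard_sizes_and_offsets(100, 3, 16, [40, 40, 20])

# Ranks 0 and 1 each get int(100 * 40 / 100) = 40 rows; the last rank takes the
# remaining 20 rows, so rounding never drops rows.
assert sizes == [[40, 16], [40, 16], [20, 16]]
assert offsets == [[0, 0], [40, 0], [80, 0]]
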
19 changes: 19 additions & 0 deletions torchrec/distributed/test_utils/infer_utils.py
@@ -76,6 +76,7 @@
     data_type_to_sparse_type,
     dtype_to_data_type,
     EmbeddingBagConfig,
+    QuantConfig,
 )
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
@@ -248,6 +249,7 @@ def quantize(
     register_tbes: bool = False,
     quant_state_dict_split_scale_bias: bool = False,
     weight_dtype: torch.dtype = torch.qint8,
+    per_table_weight_dtypes: Optional[Dict[str, torch.dtype]] = None,
 ) -> torch.nn.Module:
     module_types: List[Type[torch.nn.Module]] = [
         torchrec.modules.embedding_modules.EmbeddingBagCollection,
@@ -264,6 +266,14 @@
         activation=quant.PlaceholderObserver.with_args(dtype=output_type),
         weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype),
     )
+
+    if per_table_weight_dtypes:
+        qconfig = QuantConfig(
+            activation=quant.PlaceholderObserver.with_args(dtype=output_type),
+            weight=quant.PlaceholderObserver.with_args(dtype=torch.quint8),
+            per_table_weight_dtype=per_table_weight_dtypes,
+        )
+
     return quant.quantize_dynamic(
         module,
         qconfig_spec={
@@ -285,6 +295,7 @@ def quantize_fpebc(
     register_tbes: bool = False,
     quant_state_dict_split_scale_bias: bool = False,
     weight_dtype: torch.dtype = torch.qint8,
+    per_table_weight_dtypes: Optional[Dict[str, torch.dtype]] = None,
 ) -> torch.nn.Module:
     module_types: List[Type[torch.nn.Module]] = [
         torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection,
@@ -300,6 +311,14 @@
         activation=quant.PlaceholderObserver.with_args(dtype=output_type),
         weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype),
     )
+
+    if per_table_weight_dtypes:
+        qconfig = QuantConfig(
+            activation=quant.PlaceholderObserver.with_args(dtype=output_type),
+            weight=quant.PlaceholderObserver.with_args(dtype=torch.quint8),
+            per_table_weight_dtype=per_table_weight_dtypes,
+        )
+
     return quant.quantize_dynamic(
         module,
         qconfig_spec={
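
A hedged usage sketch of the new per_table_weight_dtypes argument on the quantize test helper; the wrapper name, the model argument, and the table names are assumptions for illustration, and only the per-table dict reflects this diff.

import torch

from torchrec.distributed.test_utils.infer_utils import quantize


def quantize_with_mixed_weight_dtypes(model: torch.nn.Module) -> torch.nn.Module:
    # When per_table_weight_dtypes is given, quantize() builds a QuantConfig with a
    # per-table weight dtype mapping instead of a single global weight_dtype.
    return quantize(
        model,
        inplace=False,
        per_table_weight_dtypes={
            "table_0": torch.qint8,   # signed 8-bit weights for this table
            "table_1": torch.quint8,  # unsigned 8-bit weights for this table
        },
    )
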
