Skip to content

Commit

Permalink
2025-01-03 nightly release (00d8ed2)
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Jan 3, 2025
1 parent d0e11e7 commit 5fd2811
Show file tree
Hide file tree
Showing 8 changed files with 377 additions and 319 deletions.
12 changes: 12 additions & 0 deletions torchrec/distributed/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torch.distributed._tensor import DTensor
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.embedding_sharding import (
EmbeddingSharding,
EmbeddingShardingContext,
Expand Down Expand Up @@ -73,6 +74,7 @@
add_params_from_parameter_sharding,
append_prefix,
convert_to_fbgemm_types,
create_global_tensor_shape_stride_from_metadata,
maybe_annotate_embedding_event,
merge_fused_params,
none_throws,
Expand Down Expand Up @@ -918,6 +920,14 @@ def _initialize_torch_state(self) -> None: # noqa
)
)
else:
shape, stride = create_global_tensor_shape_stride_from_metadata(
none_throws(self.module_sharding_plan[table_name]),
(
self._env.node_group_size
if isinstance(self._env, ShardingEnv2D)
else get_local_size(self._env.world_size)
),
)
# empty shard case
self._model_parallel_name_to_dtensor[table_name] = (
DTensor.from_local(
Expand All @@ -927,6 +937,8 @@ def _initialize_torch_state(self) -> None: # noqa
),
device_mesh=self._env.device_mesh,
run_check=False,
shape=shape,
stride=stride,
)
)
else:
Expand Down
54 changes: 27 additions & 27 deletions torchrec/distributed/test_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,7 +1192,7 @@ def __init__(
max_feature_lengths: Optional[Dict[str, int]] = None,
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
over_arch_clazz: Type[nn.Module] = TestOverArch,
preproc_module: Optional[nn.Module] = None,
postproc_module: Optional[nn.Module] = None,
) -> None:
super().__init__(
tables=cast(List[BaseEmbeddingConfig], tables),
Expand Down Expand Up @@ -1229,7 +1229,7 @@ def __init__(
"dummy_ones",
torch.ones(1, device=dense_device),
)
self.preproc_module = preproc_module
self.postproc_module = postproc_module

def sparse_forward(self, input: ModelInput) -> KeyedTensor:
return self.sparse(
Expand All @@ -1256,8 +1256,8 @@ def forward(
self,
input: ModelInput,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.preproc_module:
input = self.preproc_module(input)
if self.postproc_module:
input = self.postproc_module(input)
return self.dense_forward(input, self.sparse_forward(input))


Expand Down Expand Up @@ -1749,18 +1749,18 @@ def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:

class TestModelWithPreproc(nn.Module):
"""
Basic module with up to 3 preproc modules:
- preproc on idlist_features for non-weighted EBC
- preproc on idscore_features for weighted EBC
- optional preproc on model input shared by both EBCs
Basic module with up to 3 postproc modules:
- postproc on idlist_features for non-weighted EBC
- postproc on idscore_features for weighted EBC
- optional postproc on model input shared by both EBCs
Args:
tables,
weighted_tables,
device,
preproc_module,
postproc_module,
num_float_features,
run_preproc_inline,
run_postproc_inline,
Example:
>>> TestModelWithPreproc(tables, weighted_tables, device)
Expand All @@ -1774,9 +1774,9 @@ def __init__(
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
device: torch.device,
preproc_module: Optional[nn.Module] = None,
postproc_module: Optional[nn.Module] = None,
num_float_features: int = 10,
run_preproc_inline: bool = False,
run_postproc_inline: bool = False,
) -> None:
super().__init__()
self.dense = TestDenseArch(num_float_features, device)
Expand All @@ -1790,17 +1790,17 @@ def __init__(
is_weighted=True,
device=device,
)
self.preproc_nonweighted = TestPreprocNonWeighted()
self.preproc_weighted = TestPreprocWeighted()
self._preproc_module = preproc_module
self._run_preproc_inline = run_preproc_inline
self.postproc_nonweighted = TestPreprocNonWeighted()
self.postproc_weighted = TestPreprocWeighted()
self._postproc_module = postproc_module
self._run_postproc_inline = run_postproc_inline

def forward(
self,
input: ModelInput,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Runs preproc for EBC and weighted EBC, optionally runs preproc for input
Runs postproc for EBC and weighted EBC, optionally runs postproc for input
Args:
input
Expand All @@ -1809,20 +1809,20 @@ def forward(
"""
modified_input = input

if self._preproc_module is not None:
modified_input = self._preproc_module(modified_input)
elif self._run_preproc_inline:
if self._postproc_module is not None:
modified_input = self._postproc_module(modified_input)
elif self._run_postproc_inline:
idlist_features = modified_input.idlist_features
modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync(
idlist_features.keys(), # pyre-ignore [6]
idlist_features.values(), # pyre-ignore [6]
idlist_features.lengths(), # pyre-ignore [16]
)

modified_idlist_features = self.preproc_nonweighted(
modified_idlist_features = self.postproc_nonweighted(
modified_input.idlist_features
)
modified_idscore_features = self.preproc_weighted(
modified_idscore_features = self.postproc_weighted(
modified_input.idscore_features
)
ebc_out = self.ebc(modified_idlist_features[0])
Expand All @@ -1834,15 +1834,15 @@ def forward(

class TestNegSamplingModule(torch.nn.Module):
"""
Basic module to simulate feature augmentation preproc (e.g. neg sampling) for testing
Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing
Args:
extra_input
has_params
Example:
>>> preproc = TestNegSamplingModule(extra_input)
>>> out = preproc(in)
>>> postproc = TestNegSamplingModule(extra_input)
>>> out = postproc(in)
Returns:
ModelInput
Expand Down Expand Up @@ -1906,8 +1906,8 @@ class TestPositionWeightedPreprocModule(torch.nn.Module):
Args: None
Example:
>>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
>>> out = preproc(in)
>>> postproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
>>> out = postproc(in)
Returns:
ModelInput
"""
Expand Down
Loading

0 comments on commit 5fd2811

Please sign in to comment.