Skip to content

Commit

Permalink
Pyre fix
Browse files Browse the repository at this point in the history
Differential Revision: D66717161
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Dec 3, 2024
1 parent af4cb11 commit 3472ed2
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torchrec/distributed/model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ def sync(self, include_optimizer_state: bool = True) -> None:
all_weights = [
w
for emb_kernel in self._modules_to_sync
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
for w in emb_kernel.split_embedding_weights()
]
handle = self._replica_pg.allreduce_coalesced(all_weights, opts=opts)
Expand All @@ -755,6 +756,7 @@ def sync(self, include_optimizer_state: bool = True) -> None:
# Sync accumulated square of grad of local optimizer shards
optim_list = []
for emb_kernel in self._modules_to_sync:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
all_optimizer_states = emb_kernel.get_optimizer_state()
momentum1 = [optim["sum"] for optim in all_optimizer_states]
optim_list.extend(momentum1)
Expand Down Expand Up @@ -864,6 +866,8 @@ def _find_sharded_modules(
if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen):
sharded_modules.append(module)
if hasattr(module, "_lookups"):
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is
# not a function.
for lookup in module._lookups:
_find_sharded_modules(lookup)
return
Expand Down

0 comments on commit 3472ed2

Please sign in to comment.