2024-12-17 nightly release (e6e4f6c)
pytorchbot committed Dec 17, 2024
1 parent 7803b2d commit e4eb029
Showing 24 changed files with 25 additions and 271 deletions.
1 change: 0 additions & 1 deletion torchrec/distributed/benchmark/benchmark_train.py
@@ -10,7 +10,6 @@
#!/usr/bin/env python3

import argparse
import copy
import logging
import os
import time
109 changes: 0 additions & 109 deletions torchrec/distributed/comm_ops.py
@@ -739,115 +739,6 @@ def all2all_sequence_sync(
    return sharded_output_embeddings.view(-1, D)


def alltoallv(
    inputs: List[Tensor],
    out_split: Optional[List[int]] = None,
    per_rank_split_lengths: Optional[List[int]] = None,
    group: Optional[dist.ProcessGroup] = None,
    codecs: Optional[QuantizedCommCodecs] = None,
) -> Awaitable[List[Tensor]]:
    """
    Performs an `alltoallv` operation for a list of input embeddings. Each process
    scatters the list to all processes in the group.
    Args:
        inputs (List[Tensor]): list of tensors to scatter, one per rank. The tensors in
            the list usually have different lengths.
        out_split (Optional[List[int]]): output split sizes (or dim_sum_per_rank). If
            not specified, `per_rank_split_lengths` is used to construct an output
            split, assuming all the embeddings have the same dimension.
        per_rank_split_lengths (Optional[List[int]]): split lengths per rank. If not
            specified, `out_split` must be specified.
        group (Optional[dist.ProcessGroup]): the process group to work on. If None, the
            default process group will be used.
        codecs (Optional[QuantizedCommCodecs]): quantized communication codecs.
    Returns:
        Awaitable[List[Tensor]]: async work handle (`Awaitable`), which can be `wait()`
            later to get the resulting list of tensors.
    .. warning::
        `alltoallv` is experimental and subject to change.
    """

    if group is None:
        group = dist.distributed_c10d._get_default_group()

    world_size: int = group.size()
    my_rank: int = group.rank()

    B_global = inputs[0].size(0)

    D_local_list = []
    for e in inputs:
        D_local_list.append(e.size()[1])

    B_local, B_local_list = _get_split_lengths_by_len(world_size, my_rank, B_global)

    if out_split is not None:
        dims_sum_per_rank = out_split
    elif per_rank_split_lengths is not None:
        # all the embs have the same dimension
        dims_sum_per_rank = []
        for s in per_rank_split_lengths:
            dims_sum_per_rank.append(s * D_local_list[0])
    else:
        raise RuntimeError("Need to specify either out_split or per_rank_split_lengths")

    a2ai = All2AllVInfo(
        dims_sum_per_rank=dims_sum_per_rank,
        B_local=B_local,
        B_local_list=B_local_list,
        D_local_list=D_local_list,
        B_global=B_global,
        codecs=codecs,
    )

    if get_use_sync_collectives():
        return NoWait(all2allv_sync(group, a2ai, inputs))

    myreq = Request(group, device=inputs[0].device)
    All2Allv_Req.apply(group, myreq, a2ai, inputs)

    return myreq
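
For context, a minimal usage sketch of the removed `alltoallv` API (illustrative only, not part of this commit). It assumes `torch.distributed` has already been initialized across two ranks; the batch size and embedding dimensions below are hypothetical, and the import path reflects where the function lived before removal.

import torch
from torchrec.distributed.comm_ops import alltoallv  # pre-removal location

# Hypothetical setup: 2 ranks, global batch of 4, per-rank embedding dims 8 and 16.
B_global = 4
dim_sum_per_rank = [8, 16]

# One tensor per rank in the group, each shaped (B_global, D_rank).
inputs = [torch.randn(B_global, D) for D in dim_sum_per_rank]

# out_split carries the per-rank output dim sums (dim_sum_per_rank);
# per_rank_split_lengths could be passed instead when all dims are equal.
awaitable = alltoallv(inputs, out_split=dim_sum_per_rank)

# Block until the collective completes; returns one tensor per rank.
outputs = awaitable.wait()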


def all2allv_sync(
    pg: dist.ProcessGroup,
    a2ai: All2AllVInfo,
    inputs: List[Tensor],
) -> List[Tensor]:
    input_split_sizes = []
    sum_D_local_list = sum(a2ai.D_local_list)
    for m in a2ai.B_local_list:
        input_split_sizes.append(m * sum_D_local_list)

    output_split_sizes = []
    for e in a2ai.dims_sum_per_rank:
        output_split_sizes.append(a2ai.B_local * e)

    input = torch.cat(inputs, dim=1).view([-1])
    if a2ai.codecs is not None:
        input = a2ai.codecs.forward.encode(input)

    with record_function("## alltoallv_bwd_single ##"):
        output = torch.ops.torchrec.all_to_all_single(
            input,
            output_split_sizes,
            input_split_sizes,
            pg_name(pg),
            pg.size(),
            get_gradient_division(),
        )

    if a2ai.codecs is not None:
        output = a2ai.codecs.forward.decode(output)

    outputs = []
    for out in output.split(output_split_sizes):
        outputs.append(out.view([a2ai.B_local, -1]))
    return outputs


def reduce_scatter_pooled(
    inputs: List[Tensor],
    group: Optional[dist.ProcessGroup] = None,
5 changes: 1 addition & 4 deletions torchrec/distributed/embedding_dim_bucketer.py
@@ -10,10 +10,7 @@
from enum import Enum, unique
from typing import Dict, List

from torchrec.distributed.embedding_types import (
    EmbeddingComputeKernel,
    ShardedEmbeddingTable,
)
from torchrec.distributed.embedding_types import ShardedEmbeddingTable
from torchrec.modules.embedding_configs import DATA_TYPE_NUM_BITS, DataType


33 changes: 0 additions & 33 deletions torchrec/distributed/embedding_lookup.py
@@ -1101,36 +1101,3 @@ def get_tbes_to_register(
        self,
    ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]:
        return get_tbes_to_register_from_iterable(self._embedding_lookups_per_rank)


class InferCPUGroupedEmbeddingsLookup(
    InferGroupedLookupMixin,
    BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]],
    TBEToRegisterMixIn,
):
    def __init__(
        self,
        grouped_configs_per_rank: List[List[GroupedEmbeddingConfig]],
        world_size: int,
        fused_params: Optional[Dict[str, Any]] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self._embedding_lookups_per_rank: List[MetaInferGroupedEmbeddingsLookup] = []

        device_type: str = "cpu" if device is None else device.type
        for rank in range(world_size):
            self._embedding_lookups_per_rank.append(
                MetaInferGroupedEmbeddingsLookup(
                    grouped_configs=grouped_configs_per_rank[rank],
                    # syntax for torchscript
                    # pyre-fixme[20]: Argument `index` expected.
                    device=torch.device(type=device_type),
                    fused_params=fused_params,
                )
            )

    def get_tbes_to_register(
        self,
    ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]:
        return get_tbes_to_register_from_iterable(self._embedding_lookups_per_rank)
4 changes: 1 addition & 3 deletions torchrec/distributed/embeddingbag.py
@@ -610,9 +610,7 @@ def __init__(
        )
        self._env = env
        # output parameters as DTensor in state dict
        self._output_dtensor: bool = (
            fused_params.get("output_dtensor", False) if fused_params else False
        )
        self._output_dtensor: bool = env.output_dtensor

        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
            module,
12 changes: 6 additions & 6 deletions torchrec/distributed/fused_params.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, Optional

import torch

@@ -55,7 +55,7 @@ def is_fused_param_register_tbe(fused_params: Optional[Dict[str, Any]]) -> bool:


def get_fused_param_tbe_row_alignment(
    fused_params: Optional[Dict[str, Any]]
    fused_params: Optional[Dict[str, Any]],
) -> Optional[int]:
    if fused_params is None or FUSED_PARAM_TBE_ROW_ALIGNMENT not in fused_params:
        return None
@@ -64,7 +64,7 @@ def get_fused_param_tbe_row_alignment(


def fused_param_bounds_check_mode(
    fused_params: Optional[Dict[str, Any]]
    fused_params: Optional[Dict[str, Any]],
) -> Optional[BoundsCheckMode]:
    if fused_params is None or FUSED_PARAM_BOUNDS_CHECK_MODE not in fused_params:
        return None
@@ -73,7 +73,7 @@ def fused_param_bounds_check_mode(


def fused_param_lengths_to_offsets_lookup(
    fused_params: Optional[Dict[str, Any]]
    fused_params: Optional[Dict[str, Any]],
) -> bool:
    if (
        fused_params is None
@@ -85,7 +85,7 @@ def fused_param_lengths_to_offsets_lookup(


def is_fused_param_quant_state_dict_split_scale_bias(
    fused_params: Optional[Dict[str, Any]]
    fused_params: Optional[Dict[str, Any]],
) -> bool:
    return (
        fused_params
@@ -95,7 +95,7 @@ def is_fused_param_quant_state_dict_split_scale_bias(


def tbe_fused_params(
    fused_params: Optional[Dict[str, Any]]
    fused_params: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    if not fused_params:
        return None
1 change: 0 additions & 1 deletion torchrec/distributed/shard.py
@@ -29,7 +29,6 @@
)
from torchrec.distributed.utils import init_parameters
from torchrec.modules.utils import reset_module_states_post_sharding
from torchrec.types import CacheMixin


def _join_module_path(path: str, name: str) -> str:
5 changes: 1 addition & 4 deletions torchrec/distributed/sharding/cw_sharding.py
@@ -32,7 +32,6 @@
    DTensorMetadata,
    EmbeddingComputeKernel,
    InputDistOutputs,
    KJTList,
    ShardedEmbeddingTable,
)
from torchrec.distributed.sharding.tw_sharding import (
@@ -170,7 +169,7 @@ def _shard(
            )

            dtensor_metadata = None
            if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
            if self._env.output_dtensor:
                dtensor_metadata = DTensorMetadata(
                    mesh=self._env.device_mesh,
                    placements=(
@@ -187,8 +186,6 @@
                    ),
                    stride=info.param.stride(),
                )
            # to not pass onto TBE
            info.fused_params.pop("output_dtensor", None)  # pyre-ignore[16]

            # pyre-fixme [6]
            for i, rank in enumerate(info.param_sharding.ranks):
5 changes: 1 addition & 4 deletions torchrec/distributed/sharding/grid_sharding.py
@@ -232,7 +232,7 @@ def _shard(
            )

            dtensor_metadata = None
            if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
            if self._env.output_dtensor:
                placements = (
                    (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),)
                )
@@ -246,9 +246,6 @@
                    stride=info.param.stride(),
                )

            # to not pass onto TBE
            info.fused_params.pop("output_dtensor", None)  # pyre-ignore[16]

            # Expectation is planner CW shards across a node, so each CW shard will have local_size number of row shards
            # pyre-fixme [6]
            for i, rank in enumerate(info.param_sharding.ranks):
1 change: 0 additions & 1 deletion torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -17,7 +17,6 @@
)
from torchrec.distributed.embedding_lookup import (
    GroupedEmbeddingsLookup,
    InferCPUGroupedEmbeddingsLookup,
    InferGroupedEmbeddingsLookup,
)
from torchrec.distributed.embedding_sharding import (
from torchrec.distributed.embedding_sharding import (
4 changes: 1 addition & 3 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -179,7 +179,7 @@ def _shard(
            )

            dtensor_metadata = None
            if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
            if self._env.output_dtensor:
                placements = (
                    (Replicate(), Shard(0)) if self._is_2D_parallel else (Shard(0),)
                )
@@ -197,8 +197,6 @@
                    ),
                    stride=info.param.stride(),
                )
            # to not pass onto TBE
            info.fused_params.pop("output_dtensor", None)  # pyre-ignore[16]

            for rank in range(self._world_size):
                tables_per_rank[rank].append(
4 changes: 1 addition & 3 deletions torchrec/distributed/sharding/twrw_sharding.py
@@ -164,7 +164,7 @@ def _shard(
            )

            dtensor_metadata = None
            if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
            if self._env.output_dtensor:
                placements = (Shard(0),)
                dtensor_metadata = DTensorMetadata(
                    mesh=self._env.device_mesh,
@@ -175,8 +175,6 @@
                    ),
                    stride=info.param.stride(),
                )
            # to not pass onto TBE
            info.fused_params.pop("output_dtensor", None)  # pyre-ignore[16]

            for rank in range(
                table_node * local_size,
2 changes: 1 addition & 1 deletion torchrec/distributed/tensor_sharding.py
@@ -11,7 +11,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import cast, Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple

import torch
from torch import distributed as dist
2 changes: 0 additions & 2 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -41,7 +41,6 @@
)
from torchrec.distributed.types import (
    EmbeddingModuleShardingPlan,
    EnumerableShardingSpec,
    ModuleSharder,
    ShardedTensor,
    ShardingEnv,
@@ -288,7 +287,6 @@ def sharding_single_rank_test(
    world_size_2D: Optional[int] = None,
    node_group_size: Optional[int] = None,
) -> None:

    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
        # Generate model & inputs.
        (global_model, inputs) = gen_model_and_input(
2 changes: 1 addition & 1 deletion torchrec/distributed/tests/test_2d_sharding.py
@@ -466,8 +466,8 @@ def test_sharding_twrw_2D(

        self._test_sharding(
            world_size=self.WORLD_SIZE,
            local_size=self.WORLD_SIZE_2D // 2,
            world_size_2D=self.WORLD_SIZE_2D,
            node_group_size=self.WORLD_SIZE // 4,
            sharders=[
                cast(
                    ModuleSharder[nn.Module],