Add ShardedQuantManagedCollisionEmbeddingCollection #2649

Closed · wants to merge 1 commit
torchrec/distributed/embedding_sharding.py (18 additions, 0 deletions)
@@ -47,6 +47,7 @@
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Multistreamable


torch.fx.wrap("len")

CACHE_LOAD_FACTOR_STR: str = "cache_load_factor"
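
The pre-existing torch.fx.wrap("len") call above is the pattern the helper added below relies on: wrapped callables are treated as leaf functions during symbolic tracing instead of being traced through. A minimal, self-contained sketch of that behavior (the _Lengths module is hypothetical, not part of the diff):

import torch

torch.fx.wrap("len")  # traced graphs call the builtin at runtime


class _Lengths(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Without the wrap above, len(x) would raise on an fx Proxy during tracing.
        return x + len(x)


gm = torch.fx.symbolic_trace(_Lengths())
print(gm(torch.ones(3)))  # tensor([4., 4., 4.])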
@@ -61,6 +62,15 @@ def _fx_wrap_tensor_to_device_dtype(
return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)


@torch.fx.wrap
def _fx_wrap_optional_tensor_to_device_dtype(
t: Optional[torch.Tensor], tensor_device_dtype: torch.Tensor
) -> Optional[torch.Tensor]:
if t is None:
return None
return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)


@torch.fx.wrap
def _fx_wrap_batch_size_per_feature(kjt: KeyedJaggedTensor) -> Optional[torch.Tensor]:
return (
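
The new _fx_wrap_optional_tensor_to_device_dtype helper in the hunk above mirrors the existing _fx_wrap_tensor_to_device_dtype but tolerates None. Wrapping matters here: unwrapped, symbolic tracing would evaluate `t is None` against an fx Proxy and bake a single branch into the graph. A runnable sketch under that reading (the _Demo module is hypothetical; the helper body is copied from the diff):

from typing import Optional

import torch


@torch.fx.wrap
def _fx_wrap_optional_tensor_to_device_dtype(
    t: Optional[torch.Tensor], tensor_device_dtype: torch.Tensor
) -> Optional[torch.Tensor]:
    if t is None:
        return None
    return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)


class _Demo(torch.nn.Module):
    def forward(
        self, values: torch.Tensor, buckets: Optional[torch.Tensor]
    ) -> Optional[torch.Tensor]:
        # The wrapped call stays opaque to fx, so the None check runs at
        # call time rather than being frozen at trace time.
        return _fx_wrap_optional_tensor_to_device_dtype(buckets, values)


gm = torch.fx.symbolic_trace(_Demo())
print(gm(torch.zeros(2, dtype=torch.int64), None))             # None
print(gm(torch.zeros(2, dtype=torch.int64), torch.arange(3)))  # int64 tensor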
@@ -121,6 +131,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
block_sizes: torch.Tensor,
bucketize_pos: bool = False,
block_bucketize_pos: Optional[List[torch.Tensor]] = None,
total_num_blocks: Optional[torch.Tensor] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
@@ -142,6 +153,7 @@
bucketize_pos=bucketize_pos,
sequence=True,
block_sizes=block_sizes,
total_num_blocks=total_num_blocks,
my_size=num_buckets,
weights=kjt.weights_or_none(),
max_B=_fx_wrap_max_B(kjt),
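
The two hunks above thread the new total_num_blocks argument from the fx wrapper into the underlying block-bucketization kernel call. For orientation, here is a hedged sketch of the baseline arithmetic these arguments drive; it shows plain one-level block bucketization (bucket = id // block_size) and is an illustration only, not the kernel. How total_num_blocks reshapes the mapping for two-level bucketization is internal to the kernel and not spelled out in this diff.

import torch


def one_level_bucket(ids: torch.Tensor, block_size: int, num_buckets: int) -> torch.Tensor:
    # Classic row-wise bucketization: each consecutive run of `block_size`
    # ids lands in the same bucket; clamp guards ids past the nominal range.
    return torch.clamp(ids // block_size, max=num_buckets - 1)


ids = torch.tensor([0, 9, 10, 25, 39])
print(one_level_bucket(ids, block_size=10, num_buckets=4))
# tensor([0, 0, 1, 2, 3])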
@@ -289,6 +301,7 @@ def bucketize_kjt_inference(
kjt: KeyedJaggedTensor,
num_buckets: int,
block_sizes: torch.Tensor,
total_num_buckets: Optional[torch.Tensor] = None,
bucketize_pos: bool = False,
block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
is_sequence: bool = False,
@@ -303,6 +316,7 @@
Args:
num_buckets (int): number of buckets to bucketize the values into.
block_sizes (torch.Tensor): bucket sizes for the keyed dimension.
total_num_buckets (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization.
bucketize_pos (bool): output the changed position of the bucketized values or
not.
block_bucketize_row_pos (Optional[List[torch.Tensor]]): The offsets of shard size for each feature.
@@ -318,6 +332,9 @@
f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.",
)
block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values())
total_num_buckets_new_type = _fx_wrap_optional_tensor_to_device_dtype(
total_num_buckets, kjt.values()
)
unbucketize_permute = None
bucket_mapping = None
if is_sequence:
@@ -332,6 +349,7 @@
kjt,
num_buckets=num_buckets,
block_sizes=block_sizes_new_type,
total_num_blocks=total_num_buckets_new_type,
bucketize_pos=bucketize_pos,
block_bucketize_pos=block_bucketize_row_pos,
)
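
Putting the pieces together, a hedged usage sketch of the updated entry point. The argument values are illustrative and the shape of the returned tuple is abbreviated; only the keyword total_num_buckets is new in this PR.

import torch

from torchrec.distributed.embedding_sharding import bucketize_kjt_inference
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([1, 7, 22, 35]),
    lengths=torch.tensor([2, 2]),
)
result = bucketize_kjt_inference(
    kjt,
    num_buckets=4,
    block_sizes=torch.tensor([10]),
    # New in this PR: per-feature block count for two-level bucketization.
    total_num_buckets=torch.tensor([8]),
    is_sequence=True,
)
bucketized_kjt = result[0]  # later elements carry permute/mapping tensors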