Skip to content

Commit

Permalink
fix formatting of create_global_tensor_shape_stride_from_metadata
Browse files Browse the repository at this point in the history
Summary:
tsia
fixing: https://github.com/pytorch/torchrec/actions/runs/12640181611/job/35220063046

Differential Revision: D67875613
  • Loading branch information
iamzainhuda authored and facebook-github-bot committed Jan 6, 2025
1 parent 504642a commit 8828f65
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions torchrec/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,9 @@ def create_global_tensor_shape_stride_from_metadata(
"""
size = None
if parameter_sharding.sharding_type == ShardingType.COLUMN_WISE.value:
row_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[0] # pyre-ignore[16]
row_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[
0
] # pyre-ignore[16]
col_dim = 0
for shard in parameter_sharding.sharding_spec.shards:
col_dim += shard.shard_sizes[1]
Expand All @@ -551,4 +553,6 @@ def create_global_tensor_shape_stride_from_metadata(
for _ in range(devices_per_node):
row_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[0]
size = torch.Size([row_dim, col_dim])
return size, (size[1], 1) if size else (torch.Size([0, 0]), (0, 1)) # pyre-ignore[7]
return size, (
(size[1], 1) if size else (torch.Size([0, 0]), (0, 1))
) # pyre-ignore[7]

0 comments on commit 8828f65

Please sign in to comment.