diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py index 8a3db1209..830fef412 100644 --- a/torchrec/distributed/utils.py +++ b/torchrec/distributed/utils.py @@ -525,7 +525,9 @@ def create_global_tensor_shape_stride_from_metadata( """ size = None if parameter_sharding.sharding_type == ShardingType.COLUMN_WISE.value: - row_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[0] # pyre-ignore[16] + row_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[ + 0 + ] # pyre-ignore[16] col_dim = 0 for shard in parameter_sharding.sharding_spec.shards: col_dim += shard.shard_sizes[1] @@ -551,4 +553,6 @@ def create_global_tensor_shape_stride_from_metadata( for _ in range(devices_per_node): row_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[0] size = torch.Size([row_dim, col_dim]) - return size, (size[1], 1) if size else (torch.Size([0, 0]), (0, 1)) # pyre-ignore[7] + return size, ( + (size[1], 1) if size else (torch.Size([0, 0]), (0, 1)) + ) # pyre-ignore[7]