
Commit

formatting
rsuderman committed Oct 31, 2024
1 parent 3058af9 commit 6f65018
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -41,7 +41,7 @@ def all_gather_split(
     shards = [
         cat(
             [
-                transfer_to_logical_device(shard, i)
+                shard if i == j else transfer_to_logical_device(shard, i)
                 for j, shard in enumerate(input.shards)
             ],
             dim=dim,
@@ -62,7 +62,7 @@ def all_reduce_split_or_unreduced(
         elementwise(
             torch.add,
             *[
-                transfer_to_logical_device(shard, i)
+                shard if i == j else transfer_to_logical_device(shard, i)
                 for j, shard in enumerate(input.shards)
             ],
         )
@@ -1137,7 +1137,10 @@ def unshard_split(input: SplitPrimitiveTensor) -> Tensor:
 @unshard.override(UnreducedTensor)
 def unshard_unreduced(input: UnreducedTensor) -> Tensor:
     shards = input.shards
-    shards = [transfer_to_logical_device(shard, 0) for i, shard in enumerate(shards)]
+    shards = [
+        shard if i == 0 else transfer_to_logical_device(shard, 0)
+        for i, shard in enumerate(shards)
+    ]
     return elementwise(torch.add, *shards)


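Note on the pattern: each hunk above replaces an unconditional transfer_to_logical_device call with a conditional that leaves the shard untouched when it is already on the target device (i == j, or index 0 in unshard_unreduced), so the local shard is not routed through a self-transfer. Below is a minimal standalone sketch of that pattern, for illustration only; the placeholder transfer_to_logical_device (a clone) stands in for the sharktank helper, and all other names (all_gather_sketch and its arguments) are hypothetical.

import torch

# Placeholder for sharktank's transfer_to_logical_device: on real hardware this
# would move `t` to logical device `i`; here a clone models the copy.
def transfer_to_logical_device(t: torch.Tensor, i: int) -> torch.Tensor:
    return t.clone()

def all_gather_sketch(shards: list[torch.Tensor], dim: int = 0) -> list[torch.Tensor]:
    # For each target index i, concatenate every shard on that target. The shard
    # whose index j equals i is already local, so it is used as-is instead of
    # going through the transfer helper.
    return [
        torch.cat(
            [
                shard if i == j else transfer_to_logical_device(shard, i)
                for j, shard in enumerate(shards)
            ],
            dim=dim,
        )
        for i in range(len(shards))
    ]

# Usage: three 1-D shards; each per-target result is the full gathered tensor.
shards = [torch.arange(3), torch.arange(3, 6), torch.arange(6, 9)]
gathered = all_gather_sketch(shards)
assert all(torch.equal(g, torch.arange(9)) for g in gathered)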
