Commit

black
rsuderman committed Jan 15, 2025
1 parent cd01543 commit bd51208
Showing 3 changed files with 26 additions and 8 deletions.
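All three diffs below are mechanical reformatting from the Black autoformatter; no behavior changes. As a rough illustration (not part of this commit), Black's format_str API reproduces the wrapping pattern on a contrived snippet; the helper names are placeholders resembling the sharded_impls.py code, not imports from sharktank:

import black

# Contrived stand-in resembling the pre-format code in sharded_impls.py;
# the helper names are placeholders, not sharktank imports.
sample = (
    "shards = [\n"
    "    barrier_on_logical_device(shard, i) if i == j else transfer_to_logical_device(shard, i)\n"
    "    for j, shard in enumerate(shards)\n"
    "]\n"
)

# Once a line exceeds the default 88-column limit, recent Black releases wrap
# the conditional expression in parentheses and split it across lines, which
# is the shape of every addition in sharded_impls.py below.
print(black.format_str(sample, mode=black.Mode()))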
1 change: 1 addition & 0 deletions sharktank/sharktank/ops/default_impls.py
@@ -68,6 +68,7 @@ def conv2d_default(
 conv2d.override(Tensor, Tensor, Tensor, auto_dequant=True)(conv2d_default)
 conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default)
 
+
 # Einsum
 def mk_menk_men(inputs, weights):
     # batch dims: m, lhs pdims: none, lhs rdims: k, rhs pdims: en, rhs rdims: k
24 changes: 20 additions & 4 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -42,7 +42,11 @@ def all_gather_split(
     shards = [
         cat(
             [
-                barrier_on_logical_device(shard, i) if i == j else transfer_to_logical_device(shard, i)
+                (
+                    barrier_on_logical_device(shard, i)
+                    if i == j
+                    else transfer_to_logical_device(shard, i)
+                )
                 for j, shard in enumerate(input.shards)
             ],
             dim=dim,
@@ -63,7 +67,11 @@ def all_reduce_split_or_unreduced(
         functools.reduce(
             lambda x, y: elementwise(torch.add, x, y),
             [
-                barrier_on_logical_device(shard, i) if i == j else transfer_to_logical_device(shard, i)
+                (
+                    barrier_on_logical_device(shard, i)
+                    if i == j
+                    else transfer_to_logical_device(shard, i)
+                )
                 for j, shard in enumerate(input.shards)
             ],
         )
@@ -1085,7 +1093,11 @@ def reshard_like_unreduced_to_replicated(
 @sharded_cat.override(SplitPrimitiveTensor)
 def sharded_cat_unsharded(tensor: SplitPrimitiveTensor):
     shard_ts = [
-        transfer_to_logical_device(shard.as_torch(), 0) if i != 0 else barrier_on_logical_device(shard.as_torch(), 0)
+        (
+            transfer_to_logical_device(shard.as_torch(), 0)
+            if i != 0
+            else barrier_on_logical_device(shard.as_torch(), 0)
+        )
         for i, shard in enumerate(tensor.shards)
     ]
     return torch.cat(shard_ts, dim=tensor.shard_dim)
@@ -1177,7 +1189,11 @@ def unshard_split(input: SplitPrimitiveTensor) -> Tensor:
 def unshard_unreduced(input: UnreducedTensor) -> Tensor:
     shards = input.shards
     shards = [
-        barrier_on_logical_device(shard, i) if i == 0 else transfer_to_logical_device(shard, 0)
+        (
+            barrier_on_logical_device(shard, i)
+            if i == 0
+            else transfer_to_logical_device(shard, 0)
+        )
         for i, shard in enumerate(shards)
     ]
     return functools.reduce(lambda x, y: elementwise(torch.add, x, y), shards)
9 changes: 5 additions & 4 deletions sharktank/sharktank/ops/signatures.py
@@ -108,8 +108,9 @@ def _all_reduce_trampoline(d: SignatureDispatcher, tensor: AnyTensor):
 
 
 @overridable
-def cat(tensors: Tuple[AnyTensor, ...] | List[AnyTensor], dim: int = 0) -> AnyTensor:
-    ...
+def cat(
+    tensors: Tuple[AnyTensor, ...] | List[AnyTensor], dim: int = 0
+) -> AnyTensor: ...
 
 
 @cat.trampoline
@@ -957,8 +958,7 @@ def _sharded_cat_trampoline(d: SignatureDispatcher, maybe_sharded: AnyTensor):
 
 
 @overridable
-def sharded_sum(maybe_sharded: AnyTensor):
-    ...
+def sharded_sum(maybe_sharded: AnyTensor): ...
 
 
 @sharded_sum.trampoline
@@ -1031,6 +1031,7 @@ def _barrier_on_logical_device_trampoline(
     else:
         d.fail(tensors)
 
+
 @overridable
 def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
     """Transfer the tensor to a device with ordinal `ordinal`."""
