Skip to content

Commit

Permalink
Use correct input size for splits (#1284)
Browse files Browse the repository at this point in the history
* Use correct input size for splits

* Use smarter partitioning
  • Loading branch information
tjruwase authored Aug 6, 2021
1 parent b1b4175 commit c543a41
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion deepspeed/runtime/zero/tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def split_tensor_along_last_dim(tensor, partitions, contiguous_split_chunks=Fals
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1

# Split.
tensor_list = torch.split(tensor, partitions, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
Expand Down Expand Up @@ -123,8 +124,9 @@ def __init__(self,

def forward(self, input_):
if self.in_splits > 1 and not self.input_is_already_split:
input_parts = partition(input_.shape[-1], self.in_splits)
split_sizes = [
self.in_parts[p + 1] - self.in_parts[p] for p in range(self.in_splits)
input_parts[p + 1] - input_parts[p] for p in range(self.in_splits)
]
inputs = self._split_global_input(input_, split_sizes)
elif self.in_splits > 1:
Expand Down

0 comments on commit c543a41

Please sign in to comment.