diff --git a/deepspeed/runtime/zero/tiling.py b/deepspeed/runtime/zero/tiling.py index d78fc81515e4..3a78253df496 100644 --- a/deepspeed/runtime/zero/tiling.py +++ b/deepspeed/runtime/zero/tiling.py @@ -14,6 +14,7 @@ def split_tensor_along_last_dim(tensor, partitions, contiguous_split_chunks=Fals """ # Get the size and dimension. last_dim = tensor.dim() - 1 + # Split. tensor_list = torch.split(tensor, partitions, dim=last_dim) # Note: torch.split does not create contiguous tensors by default. @@ -123,8 +124,9 @@ def __init__(self, def forward(self, input_): if self.in_splits > 1 and not self.input_is_already_split: + input_parts = partition(input_.shape[-1], self.in_splits) split_sizes = [ - self.in_parts[p + 1] - self.in_parts[p] for p in range(self.in_splits) + input_parts[p + 1] - input_parts[p] for p in range(self.in_splits) ] inputs = self._split_global_input(input_, split_sizes) elif self.in_splits > 1: