From c543a41b154a991d50cb6cc8c07dbf46b0d2bdf6 Mon Sep 17 00:00:00 2001
From: Olatunji Ruwase
Date: Fri, 6 Aug 2021 14:32:05 -0700
Subject: [PATCH] Use correct input size for splits (#1284)

* Use correct input size for splits

* Use smarter partitioning
---
 deepspeed/runtime/zero/tiling.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/deepspeed/runtime/zero/tiling.py b/deepspeed/runtime/zero/tiling.py
index d78fc81515e4..3a78253df496 100644
--- a/deepspeed/runtime/zero/tiling.py
+++ b/deepspeed/runtime/zero/tiling.py
@@ -14,6 +14,7 @@ def split_tensor_along_last_dim(tensor, partitions, contiguous_split_chunks=Fals
     """
     # Get the size and dimension.
     last_dim = tensor.dim() - 1
+
     # Split.
     tensor_list = torch.split(tensor, partitions, dim=last_dim)
     # Note: torch.split does not create contiguous tensors by default.
@@ -123,8 +124,9 @@ def __init__(self,
 
     def forward(self, input_):
         if self.in_splits > 1 and not self.input_is_already_split:
+            input_parts = partition(input_.shape[-1], self.in_splits)
             split_sizes = [
-                self.in_parts[p + 1] - self.in_parts[p] for p in range(self.in_splits)
+                input_parts[p + 1] - input_parts[p] for p in range(self.in_splits)
             ]
             inputs = self._split_global_input(input_, split_sizes)
         elif self.in_splits > 1: