Commit
[hotfix] fixed memory usage of shardformer module replacement (hpcait…
kurisusnowdeng authored Nov 28, 2023
1 parent 7b789f4 commit 126cf18
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions colossalai/shardformer/layer/_operation.py
@@ -473,16 +473,17 @@ def forward(ctx, input_, dim, process_group):
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None


class HookParameter(torch.autograd.Function):
"""In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""

@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(weight, bias)
output = input
return output

@staticmethod
def backward(ctx, grad_output):
weight, bias = ctx.saved_tensors
@@ -491,13 +492,12 @@ def backward(ctx, grad_output):
if bias is not None:
bias = bias.view(bias.shape)
return grad_output, None, None


def hook_paramter_in_backward(input, weight=None, bias=None):
return HookParameter.apply(input, weight, bias)



def _reduce(input_, process_group):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
@@ -522,7 +522,7 @@ def _split(input_, dim=-1, process_group=None):

tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
rank = dist.get_rank(process_group)
- output = tensor_list[rank].contiguous()
+ output = tensor_list[rank].clone().contiguous()

return output

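For context (not part of the commit), here is a minimal sketch of why the added `.clone()` matters, assuming a hypothetical `full` tensor and illustrative `world_size`/`rank` values (PyTorch >= 2.0 assumed for `untyped_storage()`). Each chunk returned by `torch.split` is a view sharing storage with the full tensor; when the chunk is already contiguous (e.g. splitting along dim 0), `.contiguous()` is a no-op, so keeping only the local shard still pins the whole tensor's memory. `.clone()` copies just the shard into its own storage, letting the full tensor be freed.

```python
import torch

world_size, rank = 4, 1
full = torch.randn(1024, 1024)  # hypothetical full activation, ~4 MB of float32

# Pre-fix behaviour: the chunk is a view of `full`; along dim 0 it is already
# contiguous, so .contiguous() returns it unchanged and the shard still
# references the whole 1024x1024 storage.
view_shard = torch.split(full, full.shape[0] // world_size, dim=0)[rank].contiguous()
print(view_shard.untyped_storage().nbytes())   # 4194304 bytes (whole tensor)

# Post-fix behaviour: .clone() copies only the 256x1024 shard, so `full`
# can be garbage-collected once it goes out of scope.
owned_shard = torch.split(full, full.shape[0] // world_size, dim=0)[rank].clone().contiguous()
print(owned_shard.untyped_storage().nbytes())  # 1048576 bytes (shard only)
```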
2 changes: 1 addition & 1 deletion colossalai/tensor/d_tensor/comm_spec.py
@@ -112,7 +112,7 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec):
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
start = length * dist.get_rank(process_group)
- output = torch.narrow(tensor, dim, start, length).contiguous()
+ output = torch.narrow(tensor, dim, start, length).clone().contiguous()
return output


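The same reasoning applies here: `torch.narrow` also returns a view, and when the shard dimension is already laid out contiguously, `.contiguous()` alone does not copy it. Below is a hedged sketch with hypothetical helper names (`shard_view`, `shard_clone`) standing in for the pre-fix and post-fix behaviour of this `_split` (PyTorch >= 2.0 assumed for `untyped_storage()`):

```python
import torch

def shard_view(world_size: int = 4, rank: int = 0) -> torch.Tensor:
    # Pre-fix: torch.narrow returns a view; along dim 0 it is already
    # contiguous, so .contiguous() is a no-op and the returned shard keeps
    # the entire 1024x1024 tensor's storage alive after this function returns.
    tensor = torch.randn(1024, 1024)
    length = tensor.shape[0] // world_size
    return torch.narrow(tensor, 0, rank * length, length).contiguous()

def shard_clone(world_size: int = 4, rank: int = 0) -> torch.Tensor:
    # Post-fix: .clone() copies just the shard, so the full tensor can be
    # freed as soon as this function returns.
    tensor = torch.randn(1024, 1024)
    length = tensor.shape[0] // world_size
    return torch.narrow(tensor, 0, rank * length, length).clone().contiguous()

print(shard_view().untyped_storage().nbytes())   # 4194304 bytes: whole tensor still referenced
print(shard_clone().untyped_storage().nbytes())  # 1048576 bytes: only the 256x1024 shard
```

Because `.clone()` always allocates fresh storage, the fix trades one extra copy of the local shard for releasing the full unsharded tensor, which is where the memory saving in this hotfix comes from.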
