diff --git a/deepspeed/runtime/comm/compressed.py b/deepspeed/runtime/comm/compressed.py
index 7f8c7395451d..2c5482eb1ad7 100644
--- a/deepspeed/runtime/comm/compressed.py
+++ b/deepspeed/runtime/comm/compressed.py
@@ -96,7 +96,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro
 
         compensated_server_m.add_(server_error)
 
-        server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
+        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
 
         server_error.set_(compensated_server_m -
                           server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
diff --git a/deepspeed/runtime/comm/hccl.py b/deepspeed/runtime/comm/hccl.py
index 09fb11a731b8..b8639c7da4c9 100644
--- a/deepspeed/runtime/comm/hccl.py
+++ b/deepspeed/runtime/comm/hccl.py
@@ -83,7 +83,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro
 
         compensated_server_m.add_(server_error)
 
-        server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
+        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
 
         server_error.set_(compensated_server_m -
                           server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
diff --git a/deepspeed/runtime/fp16/onebit/lamb.py b/deepspeed/runtime/fp16/onebit/lamb.py
index 89b6f40a308c..9e7bae816ecd 100644
--- a/deepspeed/runtime/fp16/onebit/lamb.py
+++ b/deepspeed/runtime/fp16/onebit/lamb.py
@@ -177,7 +177,7 @@ def step(self, closure=None, grads=None):
             # This is used to reduce compression error during compression stage.
             momentum_scales = []
             for group in self.param_groups:
-                momentum_scales.append([(torch.linalg.norm(self.state[p]['exp_avg']) /
+                momentum_scales.append([(torch.linalg.vector_norm(self.state[p]['exp_avg']) /
                                          np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
                                         for p in group['params']])
             united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 28f91cb9b3ab..9c06567ed100 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -2101,7 +2101,7 @@ def step(self, closure=None):
             return
 
         norm_groups = self._get_norm_groups()
-        scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))
+        scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))
 
         # Stash unscaled gradient norm
         self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 0508766f8896..ed3425167944 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1691,7 +1691,8 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
                     continue
                 if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                     all_norms.append(
-                        torch.norm(g.data.double().detach(), norm_type).to(get_accelerator().current_device_name()))
+                        torch.linalg.vector_norm(g.data.double().detach(),
+                                                 ord=norm_type).to(get_accelerator().current_device_name()))
             if len(all_norms) > 0:
                 total_norm = torch.stack(all_norms).square().sum().float()
             else:
@@ -1795,7 +1796,7 @@ def scaled_global_norm(self, norm_type=2):
             self._average_expert_grad_norms(norm_groups)
 
         # calculating L2 norm
-        return torch.norm(torch.stack(norm_groups), p=norm_type)
+        return torch.linalg.vector_norm(torch.stack(norm_groups), ord=norm_type)
 
     def get_bit16_param_group(self, group_no):
         bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py
index ccc43fdf7164..ba5e596e0d6d 100644
--- a/deepspeed/sequence/layer.py
+++ b/deepspeed/sequence/layer.py
@@ -16,6 +16,71 @@
 from deepspeed.utils import groups
 
 
+def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
+    """
+    This function generates the parameters required for `permute` and `reshape` operations,
+    which are used to process data before and after `all2all` communication.
+    """
+    if batch_dim_idx == 0:
+        if scatter_idx < 2:
+            bs, global_seq_len, num_local_head, head_dim = input.shape
+            pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
+            pre_all2all_permute_idx = (1, 0, 2, 3, 4)
+
+            post_all2all_permute_idx = (1, 2, 0, 3, 4)
+            post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
+        else:
+            bs, local_seq_len, num_total_head, head_dim = input.shape
+            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
+            pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
+            pre_all2all_permute_idx = (2, 0, 1, 3, 4)
+
+            post_all2all_permute_idx = (1, 0, 2, 3, 4)
+            post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
+    else:
+        if scatter_idx < 2:
+            global_seq_len, bs, num_local_head, head_dim = input.shape
+            pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]
+            pre_all2all_permute_idx = None
+
+            post_all2all_permute_idx = (1, 2, 0, 3, 4)
+            post_all2all_res_shape = [bs, seq_world_size * global_seq_len, num_local_head // seq_world_size, head_dim]
+        else:
+            local_seq_len, bs, num_total_head, head_dim = input.shape
+            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
+            pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]
+            pre_all2all_permute_idx = (2, 0, 1, 3, 4)
+            post_all2all_permute_idx = None
+            post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim]
+
+    return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape
+
+
+def post_all2all(permute_idx, res_shape):
+    """
+    Post-processing function for `all2all` communication.
+    """
+
+    def post_func(input):
+        if permute_idx is not None:
+            input = input.permute(permute_idx).contiguous()
+        output = input.reshape(res_shape).contiguous()
+
+        return output
+
+    return post_func
+
+
+def pre_all2all_fun(permute_idx, inp_shape, input):
+    """
+    Pre-processing function for `all2all` communication.
+    """
+    input_t = input.reshape(inp_shape).contiguous()
+    if permute_idx is not None:
+        input_t = input_t.permute(permute_idx).contiguous()
+    return input_t
+
+
 def _rotate_half(x):
     """
     change sign so the last dimension becomes [-odd, +even]
@@ -43,32 +108,6 @@ def apply_rotary_pos_emb(t, freqs_cos, freqs_sin):
     return res
 
 
-def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):
-
-    def post_func(input):
-        if batch_dim_idx == 0:
-            # b, s, n, h
-            if scatter_idx < 2:
-                output = input.permute(1, 2, 0, 3, 4).contiguous()
-                output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head,
-                                        head_dim).contiguous()
-            else:
-                output = input.permute(1, 0, 2, 3, 4).contiguous()
-                output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size,
-                                        head_dim).contiguous()
-        else:
-            # s, b, n, h
-            if scatter_idx < 2:
-                output = input.permute(1, 2, 0, 3, 4).contiguous()
-                output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head,
-                                        head_dim).contiguous()
-            else:
-                output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous()
-        return output
-
-    return post_func
-
-
 def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
     seq_world_size = dist.get_world_size(group)
     inp_shape = list(input.shape)
@@ -195,39 +234,12 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
         assert async_op == False, "uneven head sp does not support async op"
         return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)
 
-    if batch_dim_idx == 0:
-        # b, s, n, h
-        if scatter_idx < 2:
-            bs, global_seq_len, num_local_head, head_dim = input.shape
-            input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head,
-                                     head_dim]).contiguous()
-            input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()
-        else:
-            bs, local_seq_len, num_total_head, head_dim = input.shape
-            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
-            input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size,
-                                     head_dim]).contiguous()
-            input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
-    else:
-        # s, b, n, h
-        if scatter_idx < 2:
-            global_seq_len, bs, num_local_head, head_dim = input.shape
-            input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head,
-                                     head_dim]).contiguous()
-        else:
-            local_seq_len, bs, num_total_head, head_dim = input.shape
-            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
-            input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size,
-                                     head_dim]).contiguous()
-            input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
+    pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params(
+        scatter_idx, batch_dim_idx, seq_world_size, input)
 
-    if scatter_idx < 2:
-        post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head,
-                                        head_dim)
-    else:
-        post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head,
-                                        head_dim)
+    input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)
+    post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
 
     output = torch.empty_like(input_t)
     work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)
@@ -236,7 +248,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
         handle[type + '_work'] = work
         handle[type + '_grad'] = output
        handle[type + '_post_all2all_func'] = post_all2all_fun
-        return output
+        return output.view(post_all2all_res_shape)
 
     res = post_all2all_fun(output)
     return res
@@ -271,7 +283,6 @@ def forward(ctx: Any,
             assert ctx.stream != None
             res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
             get_accelerator().current_stream().wait_stream(ctx.stream)
-            del ctx.stream.activation_buffer_list
             # The computation of d o_weight can overlap with the communication of d o_input
         elif not is_fwd and type in ('q', 'k'):
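
Illustrative sketch (not part of the patch): a single-process Python snippet that exercises only the shape bookkeeping of the new layout helpers for the batch-first (batch_dim_idx == 0), scatter-over-sequence (scatter_idx < 2) case. The tensor sizes and seq_world_size = 4 are assumptions chosen for the example, and the dist.all_to_all_single exchange is replaced by an identity stand-in.

# Single-process sketch of the refactored layout flow in deepspeed/sequence/layer.py.
# Only the batch_dim_idx == 0, scatter_idx < 2 branch is reproduced; the real
# dist.all_to_all_single call is skipped, so only tensor shapes are demonstrated.
import torch

def layout_params_batch_first_scatter_seq(seq_world_size, inp):
    # Trimmed copy of the corresponding branch of _generate_layout_params above.
    bs, global_seq_len, num_local_head, head_dim = inp.shape
    pre_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
    pre_permute = (1, 0, 2, 3, 4)
    post_permute = (1, 2, 0, 3, 4)
    post_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
    return pre_permute, pre_shape, post_permute, post_shape

seq_world_size = 4                     # assumed sequence-parallel degree
x = torch.randn(2, 16, 8, 64)          # [batch, global_seq, local_heads, head_dim]
pre_permute, pre_shape, post_permute, post_shape = layout_params_batch_first_scatter_seq(seq_world_size, x)

# pre_all2all_fun: reshape, then permute so dim 0 becomes the all-to-all split dimension.
x_t = x.reshape(pre_shape).contiguous().permute(pre_permute).contiguous()
assert tuple(x_t.shape) == (4, 2, 4, 8, 64)

# dist.all_to_all_single(output, x_t, group=group) would exchange along dim 0 here;
# it preserves the tensor shape, so an identity stand-in keeps the bookkeeping intact.
output = x_t

# post_all2all: permute back and fold the gathered head shards into one head dimension.
res = output.permute(post_permute).contiguous().reshape(post_shape)
assert tuple(res.shape) == (2, 4, 32, 64)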