From 81cc32075c721c28378153ee34df11a6d4591db3 Mon Sep 17 00:00:00 2001
From: Max Kovalenko <75629718+deepcharm@users.noreply.github.com>
Date: Wed, 3 Jan 2024 01:44:12 +0200
Subject: [PATCH] Partition parameters: Minor refactoring of
 use_secondary_tensor condition (#4868)

Introduce use_secondary_tensor bool variable to shorten notation and improve
readability.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 .../runtime/zero/partition_parameters.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index cdf7de512b9b..992dcd446ad6 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -1064,7 +1064,9 @@ def all_gather(param_list=None, async_op=False, hierarchy=0):
 
         def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group):
             partition_sz = sum(p.ds_tensor.ds_numel for p in params)
-            if params[0].ds_secondary_tensor is not None and not forward:
+            use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward
+
+            if use_secondary_tensor:
                 partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
 
             flat_tensor = torch.empty(partition_sz * world_size,
@@ -1076,13 +1078,11 @@ def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_proc
             for i in range(world_size):
                 partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))
 
-            if params[0].ds_secondary_tensor is not None and not forward:
-                use_secondary_tensor = True
+            if use_secondary_tensor:
                 instrument_w_nvtx(
                     torch.cat)([p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params],
                                out=partitions[rank_in_group])
             else:
-                use_secondary_tensor = False
                 instrument_w_nvtx(torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params],
                                              out=partitions[rank_in_group])
             handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)
@@ -1118,7 +1118,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
             ds_process_group = self.ds_process_group
             rank_in_group = self.rank
             world_size = self.dp_world_size
-            use_secondary_tensor = False
+            use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward
             if self.zero_param_process_group and not forward:
                 ds_process_group = self.zero_param_process_group  #intragroup
                 rank_in_group = self.rank_in_group
@@ -1149,10 +1149,10 @@ def all_gather_coalesced(params: Iterable[Parameter],
                 # have an opportunity to avoid some intermediate memory allocations
                 param, = params
                 buffer_size = math.ceil(param.ds_numel / world_size) * world_size
-                if not forward and param.ds_secondary_tensor is not None:
+                if use_secondary_tensor:
                     buffer_size = param.ds_secondary_tensor.shape[0] * world_size  #make sure out is appropriately sized
 
-                param_ds_tensor = param.ds_secondary_tensor if not forward and param.ds_secondary_tensor is not None else param.ds_tensor
+                param_ds_tensor = param.ds_secondary_tensor if use_secondary_tensor else param.ds_tensor
                 param_buffer = torch.empty(
                     buffer_size,
                     dtype=param_ds_tensor.dtype if not quantize else torch.int8,
@@ -1207,7 +1207,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
                 else:
                     partition_sz = sum(p.ds_tensor.ds_numel for p in params)
-                    if params[0].ds_secondary_tensor is not None and not forward:
+                    if use_secondary_tensor:
                         partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups
                                            for p in params)
 
                     flat_tensor = torch.empty(partition_sz * world_size,
@@ -1215,8 +1215,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
                                               device=get_accelerator().current_device_name(),
                                               requires_grad=False)
 
-                    if params[0].ds_secondary_tensor is not None and not forward:
-                        use_secondary_tensor = True
+                    if use_secondary_tensor:
                         if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"):
                             quantized_param = instrument_w_nvtx(torch.cat)([
                                 p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params
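
For context, the refactoring above hoists a condition that was previously re-evaluated (and re-assigned) in several branches into a single use_secondary_tensor boolean computed once per call. The following is a minimal, standalone sketch of that pattern only; the Param class and partition_size() function are hypothetical stand-ins for illustration and are not DeepSpeed APIs.

# Minimal sketch of the pattern applied in this patch: a condition that used to
# be repeated in several branches is computed once and given a name.
# Param and partition_size() are hypothetical stand-ins, not DeepSpeed APIs.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Param:
    numel: int                              # stand-in for p.ds_tensor.ds_numel
    secondary_numel: Optional[int] = None   # stand-in for p.ds_secondary_tensor


def partition_size(params: List[Param], forward: bool) -> int:
    # Evaluate the condition once, instead of repeating
    # "params[0].secondary_numel is not None and not forward" in every branch.
    use_secondary_tensor = params[0].secondary_numel is not None and not forward

    if use_secondary_tensor:
        return sum(p.secondary_numel for p in params)
    return sum(p.numel for p in params)


if __name__ == "__main__":
    params = [Param(numel=4, secondary_numel=8), Param(numel=2, secondary_numel=6)]
    print(partition_size(params, forward=True))   # 6  -> primary tensors
    print(partition_size(params, forward=False))  # 14 -> secondary tensors

Naming the condition once keeps the forward/backward distinction in a single place, so the later branches cannot drift out of sync with one another.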