Partition parameters: Minor refactoring of use_secondary_tensor condition (#4868)

Introduce use_secondary_tensor bool variable to shorten notation
and improve readability.

Co-authored-by: Logan Adams <[email protected]>
deepcharm and loadams authored Jan 2, 2024
1 parent d873ce6 commit 81cc320
Showing 1 changed file with 9 additions and 10 deletions.
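
The change is a straightforward refactor: a compound condition that was re-evaluated at several branch points is hoisted into a single, descriptively named boolean. Below is a minimal standalone sketch of the pattern, not the DeepSpeed code itself; FakeParam, partition_size, and their fields are hypothetical names used only for illustration.

# Minimal standalone sketch of the refactoring pattern (assumed names; this is
# not the DeepSpeed implementation). The compound condition is evaluated once,
# bound to a descriptive boolean, and every later branch reads that boolean.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeParam:
    """Hypothetical stand-in for a ZeRO-partitioned parameter."""
    numel: int
    secondary_numel: Optional[int] = None  # set only when a secondary tensor exists


def partition_size(params: List[FakeParam], forward: bool) -> int:
    # Before the refactor, this condition was repeated at every branch point;
    # after it, the intent is stated once and reused.
    use_secondary_tensor = params[0].secondary_numel is not None and not forward

    if use_secondary_tensor:
        return sum(p.secondary_numel for p in params)
    return sum(p.numel for p in params)


if __name__ == "__main__":
    params = [FakeParam(numel=8, secondary_numel=16), FakeParam(numel=4, secondary_numel=8)]
    print(partition_size(params, forward=True))   # 12: primary-tensor path
    print(partition_size(params, forward=False))  # 24: secondary-tensor path

Binding the condition once also guarantees that every branch makes the same decision, which is the readability and notation gain the commit message describes.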
19 changes: 9 additions & 10 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -1064,7 +1064,9 @@ def all_gather(param_list=None, async_op=False, hierarchy=0):
 def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group):
     partition_sz = sum(p.ds_tensor.ds_numel for p in params)
 
-    if params[0].ds_secondary_tensor is not None and not forward:
+    use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward
+
+    if use_secondary_tensor:
         partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
 
     flat_tensor = torch.empty(partition_sz * world_size,
@@ -1076,13 +1078,11 @@ def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group):
     for i in range(world_size):
         partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))
 
-    if params[0].ds_secondary_tensor is not None and not forward:
-        use_secondary_tensor = True
+    if use_secondary_tensor:
         instrument_w_nvtx(
             torch.cat)([p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params],
                        out=partitions[rank_in_group])
     else:
-        use_secondary_tensor = False
         instrument_w_nvtx(torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params],
                                      out=partitions[rank_in_group])
     handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)
@@ -1118,7 +1118,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
             ds_process_group = self.ds_process_group
             rank_in_group = self.rank
             world_size = self.dp_world_size
-            use_secondary_tensor = False
+            use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward
             if self.zero_param_process_group and not forward:
                 ds_process_group = self.zero_param_process_group #intragroup
                 rank_in_group = self.rank_in_group
@@ -1149,10 +1149,10 @@ def all_gather_coalesced(params: Iterable[Parameter],
                 # have an opportunity to avoid some intermediate memory allocations
                 param, = params
                 buffer_size = math.ceil(param.ds_numel / world_size) * world_size
-                if not forward and param.ds_secondary_tensor is not None:
+                if use_secondary_tensor:
                     buffer_size = param.ds_secondary_tensor.shape[0] * world_size #make sure out is appropriately sized
 
-                param_ds_tensor = param.ds_secondary_tensor if not forward and param.ds_secondary_tensor is not None else param.ds_tensor
+                param_ds_tensor = param.ds_secondary_tensor if use_secondary_tensor else param.ds_tensor
                 param_buffer = torch.empty(
                     buffer_size,
                     dtype=param_ds_tensor.dtype if not quantize else torch.int8,
@@ -1207,16 +1207,15 @@ def all_gather_coalesced(params: Iterable[Parameter],
                 else:
                     partition_sz = sum(p.ds_tensor.ds_numel for p in params)
 
-                    if params[0].ds_secondary_tensor is not None and not forward:
+                    if use_secondary_tensor:
                         partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
 
                     flat_tensor = torch.empty(partition_sz * world_size,
                                               dtype=torch.int8,
                                               device=get_accelerator().current_device_name(),
                                               requires_grad=False)
 
-                    if params[0].ds_secondary_tensor is not None and not forward:
-                        use_secondary_tensor = True
+                    if use_secondary_tensor:
                         if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"):
                             quantized_param = instrument_w_nvtx(torch.cat)([
                                 p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params
