Commit

Merge branch 'master' into zeroOptParamsFlatenning
tjruwase authored Jan 3, 2024
2 parents 8b0d4ce + 81cc320 commit 2e41d9e
Showing 8 changed files with 36 additions and 27 deletions.
3 changes: 2 additions & 1 deletion csrc/includes/cpu_adagrad.h
@@ -194,14 +194,15 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memoryCopySize /= 2;
if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
6 changes: 3 additions & 3 deletions csrc/includes/cpu_adam.h
@@ -215,8 +215,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream((_streams[_buf_index].stream());
}
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
@@ -274,14 +273,15 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memoryCopySize /= 2;
if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
3 changes: 2 additions & 1 deletion csrc/includes/cpu_lion.h
@@ -223,14 +223,15 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memoryCopySize /= 2;
if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/contiguous_memory_allocator.py
@@ -193,7 +193,7 @@ def _defragment_memory(self):
tensor = self.tensor_map[self.tensor_ids[tensor_addr]]

assert tensor_size == tensor.numel(), \
"Size mismatch. {tensor_size} is allocated at addr {tensor_addr} but tensor size is {tensor.numel()} "
f"Size mismatch. {tensor_size} is allocated at addr {tensor_addr} but tensor size is {tensor.numel()} "

assert empty_addr != tensor_addr, \
f"Cannot have same empty address {empty_addr} and tensor address {tensor_addr}"
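Note on the assertion-message fix above (and the similar ones in stage_1_and_2.py and test_ds_config_dict.py further down): it adds the missing `f` prefix. Without it, Python keeps the `{...}` placeholders as literal text instead of interpolating them. A minimal, standalone illustration with made-up values:

```python
tensor_size, tensor_addr = 1024, 64

# Plain string: the placeholders are printed verbatim.
print("Size mismatch. {tensor_size} is allocated at addr {tensor_addr}")
# -> Size mismatch. {tensor_size} is allocated at addr {tensor_addr}

# f-string: the values are interpolated.
print(f"Size mismatch. {tensor_size} is allocated at addr {tensor_addr}")
# -> Size mismatch. 1024 is allocated at addr 64
```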
19 changes: 9 additions & 10 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -1064,7 +1064,9 @@ def all_gather(param_list=None, async_op=False, hierarchy=0):
def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group):
partition_sz = sum(p.ds_tensor.ds_numel for p in params)

if params[0].ds_secondary_tensor is not None and not forward:
use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward

if use_secondary_tensor:
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)

flat_tensor = torch.empty(partition_sz * world_size,
@@ -1076,13 +1078,11 @@ def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_proc
for i in range(world_size):
partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))

if params[0].ds_secondary_tensor is not None and not forward:
use_secondary_tensor = True
if use_secondary_tensor:
instrument_w_nvtx(
torch.cat)([p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params],
out=partitions[rank_in_group])
else:
use_secondary_tensor = False
instrument_w_nvtx(torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params],
out=partitions[rank_in_group])
handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)
@@ -1118,7 +1118,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
ds_process_group = self.ds_process_group
rank_in_group = self.rank
world_size = self.dp_world_size
use_secondary_tensor = False
use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward
if self.zero_param_process_group and not forward:
ds_process_group = self.zero_param_process_group #intragroup
rank_in_group = self.rank_in_group
@@ -1149,10 +1149,10 @@ def all_gather_coalesced(params: Iterable[Parameter],
# have an opportunity to avoid some intermediate memory allocations
param, = params
buffer_size = math.ceil(param.ds_numel / world_size) * world_size
if not forward and param.ds_secondary_tensor is not None:
if use_secondary_tensor:
buffer_size = param.ds_secondary_tensor.shape[0] * world_size #make sure out is appropriately sized

param_ds_tensor = param.ds_secondary_tensor if not forward and param.ds_secondary_tensor is not None else param.ds_tensor
param_ds_tensor = param.ds_secondary_tensor if use_secondary_tensor else param.ds_tensor
param_buffer = torch.empty(
buffer_size,
dtype=param_ds_tensor.dtype if not quantize else torch.int8,
@@ -1207,16 +1207,15 @@ def all_gather_coalesced(params: Iterable[Parameter],
else:
partition_sz = sum(p.ds_tensor.ds_numel for p in params)

if params[0].ds_secondary_tensor is not None and not forward:
if use_secondary_tensor:
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)

flat_tensor = torch.empty(partition_sz * world_size,
dtype=torch.int8,
device=get_accelerator().current_device_name(),
requires_grad=False)

if params[0].ds_secondary_tensor is not None and not forward:
use_secondary_tensor = True
if use_secondary_tensor:
if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"):
quantized_param = instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params
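Note on the partition_parameters.py changes above: the predicate `params[0].ds_secondary_tensor is not None and not forward` is now evaluated once into `use_secondary_tensor` and reused, instead of being re-checked at every branch. A simplified sketch of the pattern (the helper name and the trimmed-down logic are illustrative, not the actual DeepSpeed code):

```python
def compute_partition_size(params, forward, world_size):
    # Decide once whether the secondary (sub-group) partition applies for this pass.
    use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward

    if use_secondary_tensor:
        partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)
    else:
        partition_sz = sum(p.ds_tensor.ds_numel for p in params)

    # The same flag is reused later when choosing which shard to gather,
    # so the condition cannot drift out of sync between call sites.
    return partition_sz * world_size, use_secondary_tensor
```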
24 changes: 16 additions & 8 deletions deepspeed/runtime/zero/stage3.py
@@ -1328,7 +1328,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
param_id = self.get_param_id(p)
if param_id in self.norm_for_param_grads.keys():
param_norm = self.norm_for_param_grads[param_id]
total_norm += param_norm.item()**2
total_norm += param_norm**2

# Sum across all model parallel GPUs.
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
@@ -1337,10 +1337,14 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):

self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)

total_norm = total_norm_cuda[0].item()**(1. / norm_type)
total_norm = total_norm_cuda[0]**(1. / norm_type)

if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
norm_is_inf = total_norm.isinf()
norm_is_nan = total_norm.isnan()
inf_or_nan = norm_is_nan.logical_or(norm_is_inf)

err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm

return total_norm

@@ -1669,7 +1673,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):

# Take max across all GPUs.
self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
total_norm = total_norm_cuda[0]
else:
# if dist.get_rank() == 0:
# logger.info(f"Total Norm beginning {total_norm}")
@@ -1690,10 +1694,14 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):

self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)

total_norm = total_norm_cuda.item()**(1. / norm_type)
total_norm = total_norm_cuda**(1. / norm_type)

norm_is_inf = total_norm.isinf()
norm_is_nan = total_norm.isnan()
inf_or_nan = norm_is_nan.logical_or(norm_is_inf)

if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm

return total_norm

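Note on the stage3.py changes above: the grad-norm computation now stays in tensor form (no `.item()` calls, so no forced device-to-host synchronization), and the inf/nan check is expressed with tensor ops instead of Python float comparisons. A minimal sketch of that pattern, here using `torch.where` for the sentinel substitution rather than the masked arithmetic shown in the diff; the function name is hypothetical:

```python
import torch

def finalize_total_norm(total_norm_sq: torch.Tensor, norm_type: float = 2.0) -> torch.Tensor:
    """Turn an on-device sum of squared norms into a final norm without syncing to the host."""
    total_norm = total_norm_sq ** (1.0 / norm_type)

    # Detect inf/nan entirely on-device.
    inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan())

    # Substitute the -1.0 sentinel used when the norm is unusable.
    err = torch.tensor(-1.0, device=total_norm.device, dtype=total_norm.dtype)
    return torch.where(inf_or_nan, err, total_norm)
```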
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -235,8 +235,8 @@ def __init__(self,
if self.reduce_scatter:
valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
assert self.postscale_gradients, "pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"

# param flattened by groups
self.bit16_groups = []
2 changes: 1 addition & 1 deletion tests/unit/runtime/test_ds_config_dict.py
@@ -90,7 +90,7 @@ class TestBatchConfig(DistributedTest):

def test(self, num_ranks, batch, micro_batch, gas, success):
assert dist.get_world_size() == num_ranks, \
'The test assumes a world size of f{num_ranks}'
f'The test assumes a world size of {num_ranks}'

ds_batch_config = get_test_path('ds_batch_config.json')
ds_config = DeepSpeedConfig(ds_batch_config)
