
Commit

Merge branch 'master' into gma/fix_cpu_inference
tjruwase authored Jan 3, 2024
2 parents a72beea + 81cc320 commit 057b6ff
Showing 7 changed files with 20 additions and 19 deletions.
3 changes: 2 additions & 1 deletion csrc/includes/cpu_adagrad.h
@@ -194,14 +194,15 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
- if (half_precision) memoryCopySize /= 2;
+ if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
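Note: the CANN branches in cpu_adagrad.h, cpu_adam.h, and cpu_lion.h all receive the same one-line fix: the undefined identifier memoryCopySize, which would not compile under __ENABLE_CANN__, is replaced with the declared memcpy_size. A minimal sketch of the size arithmetic being guarded, using host-side PyTorch as a stand-in for the CANN staging buffers (all names illustrative):

    import torch

    # The host staging buffer holds float32 elements, so the byte count
    # for the device copy starts at copy_size * sizeof(float32). When the
    # device parameters are half precision, each element occupies 2 bytes
    # instead of 4 -- hence `if (half_precision) memcpy_size /= 2;`.
    copy_size = 1024
    host_buffer = torch.empty(copy_size, dtype=torch.float32)

    memcpy_size = copy_size * host_buffer.element_size()  # 4096 bytes (fp32)
    half_precision = True
    if half_precision:
        memcpy_size //= 2                                 # 2048 bytes (fp16)

    assert memcpy_size == copy_size * torch.empty(0, dtype=torch.float16).element_size()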
6 changes: 3 additions & 3 deletions csrc/includes/cpu_adam.h
@@ -215,8 +215,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
- if ((t / TILE) >= 2) { aclrtSynchronizeStream((_streams[_buf_index].stream());
- }
+ if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
@@ -274,14 +273,15 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
- if (half_precision) memoryCopySize /= 2;
+ if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
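Note: besides the memcpy_size rename, this file fixes an unbalanced parenthesis in the CANN stream synchronization (the removed line above would not parse). For context, a sketch of the double-buffering pattern that synchronization serves, using CUDA streams as a stand-in for the aclrt API (assumes a CUDA device; names illustrative):

    import torch

    TILE = 1024
    streams = [torch.cuda.Stream(), torch.cuda.Stream()]
    host_bufs = [torch.empty(TILE, pin_memory=True) for _ in range(2)]
    dev_params = torch.empty(4 * TILE, device="cuda")

    buf_index = 0
    for t in range(0, dev_params.numel(), TILE):
        # Before reusing a buffer, wait for the copy issued two tiles ago --
        # the same guard as `if ((t / TILE) >= 2) aclrtSynchronizeStream(...)`.
        if (t // TILE) >= 2:
            streams[buf_index].synchronize()
        host_bufs[buf_index].fill_(float(t))          # "compute" this tile
        with torch.cuda.stream(streams[buf_index]):   # async host-to-device copy
            dev_params[t:t + TILE].copy_(host_bufs[buf_index], non_blocking=True)
        buf_index = 1 - buf_index                     # mirrors `_buf_index = !_buf_index`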
3 changes: 2 additions & 1 deletion csrc/includes/cpu_lion.h
@@ -223,14 +223,15 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
- if (half_precision) memoryCopySize /= 2;
+ if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/contiguous_memory_allocator.py
@@ -193,7 +193,7 @@ def _defragment_memory(self):
tensor = self.tensor_map[self.tensor_ids[tensor_addr]]

assert tensor_size == tensor.numel(), \
"Size mismatch. {tensor_size} is allocated at addr {tensor_addr} but tensor size is {tensor.numel()} "
f"Size mismatch. {tensor_size} is allocated at addr {tensor_addr} but tensor size is {tensor.numel()} "

assert empty_addr != tensor_addr, \
f"Cannot have same empty address {empty_addr} and tensor address {tensor_addr}"
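Note: the fix here, like the ones in stage_1_and_2.py and test_ds_config_dict.py further down, is a missing or misplaced f prefix: the braces in the assertion message were literal text instead of interpolated values. A quick illustration:

    tensor_size, tensor_addr = 512, 16

    # Without the f prefix the placeholders print verbatim:
    print("Size mismatch. {tensor_size} is allocated at addr {tensor_addr}")
    # -> Size mismatch. {tensor_size} is allocated at addr {tensor_addr}

    # With it, the values are interpolated:
    print(f"Size mismatch. {tensor_size} is allocated at addr {tensor_addr}")
    # -> Size mismatch. 512 is allocated at addr 16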
19 changes: 9 additions & 10 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -1064,7 +1064,9 @@ def all_gather(param_list=None, async_op=False, hierarchy=0):
def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group):
partition_sz = sum(p.ds_tensor.ds_numel for p in params)

- if params[0].ds_secondary_tensor is not None and not forward:
+ use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward
+
+ if use_secondary_tensor:
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)

flat_tensor = torch.empty(partition_sz * world_size,
@@ -1076,13 +1078,11 @@ def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group):
for i in range(world_size):
partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))

- if params[0].ds_secondary_tensor is not None and not forward:
- use_secondary_tensor = True
+ if use_secondary_tensor:
instrument_w_nvtx(
torch.cat)([p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params],
out=partitions[rank_in_group])
else:
- use_secondary_tensor = False
instrument_w_nvtx(torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params],
out=partitions[rank_in_group])
handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)
@@ -1118,7 +1118,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
ds_process_group = self.ds_process_group
rank_in_group = self.rank
world_size = self.dp_world_size
- use_secondary_tensor = False
+ use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward
if self.zero_param_process_group and not forward:
ds_process_group = self.zero_param_process_group #intragroup
rank_in_group = self.rank_in_group
@@ -1149,10 +1149,10 @@ def all_gather_coalesced(params: Iterable[Parameter],
# have an opportunity to avoid some intermediate memory allocations
param, = params
buffer_size = math.ceil(param.ds_numel / world_size) * world_size
- if not forward and param.ds_secondary_tensor is not None:
+ if use_secondary_tensor:
buffer_size = param.ds_secondary_tensor.shape[0] * world_size #make sure out is appropriately sized

- param_ds_tensor = param.ds_secondary_tensor if not forward and param.ds_secondary_tensor is not None else param.ds_tensor
+ param_ds_tensor = param.ds_secondary_tensor if use_secondary_tensor else param.ds_tensor
param_buffer = torch.empty(
buffer_size,
dtype=param_ds_tensor.dtype if not quantize else torch.int8,
@@ -1207,16 +1207,15 @@ def all_gather_coalesced(params: Iterable[Parameter],
else:
partition_sz = sum(p.ds_tensor.ds_numel for p in params)

- if params[0].ds_secondary_tensor is not None and not forward:
+ if use_secondary_tensor:
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)

flat_tensor = torch.empty(partition_sz * world_size,
dtype=torch.int8,
device=get_accelerator().current_device_name(),
requires_grad=False)

- if params[0].ds_secondary_tensor is not None and not forward:
- use_secondary_tensor = True
+ if use_secondary_tensor:
if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"):
quantized_param = instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params
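Note: the thread running through this file's hunks is that the predicate "a secondary (hierarchically partitioned) tensor exists and this is not the forward pass" was re-tested at each call site, alongside a separately maintained use_secondary_tensor flag. The patch computes the flag once per call and keys every branch off it. A distilled, self-contained sketch of the refactored shape (the class and field names below merely mimic the real parameter objects):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeParam:            # stand-in for a ZeRO-3 partitioned parameter
        ds_tensor: str
        ds_secondary_tensor: Optional[str] = None

    def gather_sources(params, forward):
        # Derive the flag once; every later branch reuses it instead of
        # re-testing `ds_secondary_tensor is not None and not forward`.
        use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward
        return [p.ds_secondary_tensor if use_secondary_tensor else p.ds_tensor
                for p in params]

    params = [FakeParam("p0", "s0"), FakeParam("p1", "s1")]
    assert gather_sources(params, forward=True) == ["p0", "p1"]
    assert gather_sources(params, forward=False) == ["s0", "s1"]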
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -240,8 +240,8 @@ def __init__(self,
if self.reduce_scatter:
valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
- assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
- assert self.postscale_gradients, "pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
+ assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
+ assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"

# param flattened by groups
self.bit16_groups = []
2 changes: 1 addition & 1 deletion tests/unit/runtime/test_ds_config_dict.py
@@ -90,7 +90,7 @@ class TestBatchConfig(DistributedTest):

def test(self, num_ranks, batch, micro_batch, gas, success):
assert dist.get_world_size() == num_ranks, \
- 'The test assumes a world size of f{num_ranks}'
+ f'The test assumes a world size of {num_ranks}'

ds_batch_config = get_test_path('ds_batch_config.json')
ds_config = DeepSpeedConfig(ds_batch_config)
