Skip to content

Commit

Permalink
Merge branch 'master' into use-all-reduce-for-fetch-params
Browse files: browse the repository at this point in the history
  • Loading branch information
tjruwase authored Apr 23, 2024
2 parents ec511f4 + c66bc42 commit cde67ab
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 54 deletions.
77 changes: 56 additions & 21 deletions csrc/adam/multi_tensor_adam.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ typedef enum : int {

using MATH_T = float;

template <typename T>
template <typename T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
volatile int* noop_gmem,
Expand All @@ -48,13 +48,13 @@ struct AdamFunctor {
// if(*noop_gmem == 1)
// return;

int tensor_loc = tl.block_to_tensor[blockIdx.x];
index_t tensor_loc = tl.block_to_tensor[blockIdx.x];

// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;

int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc];

T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
Expand All @@ -71,7 +71,8 @@ struct AdamFunctor {
n -= chunk_idx * chunk_size;

// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
for (index_t i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
Expand Down Expand Up @@ -146,23 +147,57 @@ void multi_tensor_adam_cuda(int chunk_size,
bias_correction2 = 1 - std::pow(beta2, step);
}

size_t max_size = 0;
bool requires_64bit_indexing = false;
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
if (it2->numel() > max_size) {
max_size = it2->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
}
}
}
if (requires_64bit_indexing) { break; }
}

// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
if (requires_64bit_indexing) {
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE,
(int64_t)chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, int64_t>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
} else {
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, int32_t>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
}

AT_CUDA_CHECK(cudaGetLastError());
}
10 changes: 5 additions & 5 deletions csrc/adam/multi_tensor_apply.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct TensorListMetadata {
};

template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
__global__ void multi_tensor_apply_kernel(int64_t chunk_size,
volatile int* noop_flag,
T tl,
U callable,
Expand All @@ -46,8 +46,8 @@ __global__ void multi_tensor_apply_kernel(int chunk_size,
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
int chunk_size,
void multi_tensor_apply(int64_t block_size,
int64_t chunk_size,
const at::Tensor& noop_flag,
const std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
Expand Down Expand Up @@ -91,9 +91,9 @@ void multi_tensor_apply(int block_size,
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;

int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;

for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
Expand Down
34 changes: 10 additions & 24 deletions deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import argparse
import glob
import itertools
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
import os
import re
import shutil
Expand Down Expand Up @@ -292,27 +292,18 @@ def get_matched_sub_params_pattern(name_):
return unmatched_patterns


def _get_chunks(l, n):
for i in range(0, len(l), n):
yield l[i:i + n]


def _do_parallel_work(do_work, work_chunks, num_workers):
results = []
if num_workers > 1:
pool = multiprocessing.Pool(num_workers)
results = []
for batch in tqdm.tqdm(work_chunks):
res = pool.map(do_work, batch)
results.extend(res)
pool.close()
pool.join()
with ProcessPoolExecutor(max_workers=num_workers) as executor:
future_list = [executor.submit(do_work, work) for work in work_chunks]
for f in tqdm.tqdm(future_list):
results.append(f.result())
else:
# No parallel pass for unit testing
# We can't create child processes in tests
results = []
for batch in tqdm.tqdm(work_chunks):
res = [do_work(x) for x in batch]
results.extend(res)
for work in tqdm.tqdm(work_chunks):
results.append(do_work(work))
return results


Expand All @@ -321,20 +312,15 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
range(ds_checkpoint.dp_degree)))
#pprint(f'{_3d_range_list=}')
work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
#pprint(f'{work_chunks=}')

# extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
_do_parallel_work(do_work, work_chunks, args.num_extract_workers)
_do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)


def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers))
#pprint(work_chunks)
zero_output_folder = os.path.join(args.output_folder, "zero")
do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers)
unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers)

# verify that all patterns were used
# if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
Expand Down
17 changes: 13 additions & 4 deletions deepspeed/runtime/bf16_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads):

# clear gradients
if clear_lp_grads:
lp.grad._zero()
lp.grad.zero_()

@torch.no_grad()
def _update_hp_grads_func(self, clear_lp_grads=False):
Expand Down Expand Up @@ -441,11 +441,20 @@ def clear_hp_grads(self):
self.fp32_groups_has_gradients[i] = [False] * len(group)

def clear_lp_grads(self):

# Zeroing gradients in place (instead of setting them to None) keeps their memory addresses fixed for graph replay
set_to_none = False if self.graph_harvesting else True
zero_grads_list = []
for group in self.bf16_groups:
for param in group:
if param.grad is not None:
# Using zero_() fixed memory address for graph replay
param.grad.zero_()
if set_to_none:
param.grad = None
elif param.grad is not None:
if param.grad.grad_fn is not None:
param.grad.detach_()
zero_grads_list.append(param.grad)
if not set_to_none and len(zero_grads_list) > 0:
torch._foreach_zero_(zero_grads_list)

def state_dict(self):
state_dict = {}
Expand Down

0 comments on commit cde67ab

Please sign in to comment.