Merge branch 'master' into add-safetensors-support

loadams authored Nov 9, 2023
2 parents ce3a075 + da652d0 commit 64bf60e

Showing 12 changed files with 152 additions and 98 deletions.
47 changes: 28 additions & 19 deletions .gitignore
@@ -1,31 +1,40 @@
# Ignore Python compiled files
*.pyc
.idea/
*~
*.swp

# Ignore IDE-specific files and directories
.idea/ # JetBrains IDE settings
.vscode/ # Visual Studio Code settings
.theia/ # Theia IDE settings

# Ignore temporary and backup files
*~ # General backup files
*.swp # Vim swap files

# Ignore log files
*.log

# Ignore a specific generated file
deepspeed/git_version_info_installed.py

# Ignore Python bytecode cache
__pycache__

# Build + installation data
build/
dist/
*.so
deepspeed.egg-info/
build.txt

# Website
docs/_site/
docs/build
build/ # Build artifacts
dist/ # Distribution files
*.so # Compiled shared objects
deepspeed.egg-info/ # Deepspeed package info
build.txt # Build information

# Website generated files
docs/_site/ # Jekyll generated site
docs/build # Generated documentation
docs/code-docs/source/_build
docs/code-docs/_build
docs/code-docs/build
.sass-cache/
.jekyll-cache/
.sass-cache/ # SASS cache
.jekyll-cache/ # Jekyll cache
.jekyll-metadata

# Testing data
tests/unit/saved_checkpoint/

# Dev/IDE data
.vscode
.theia
tests/unit/saved_checkpoint/ # Saved checkpoints for testing
12 changes: 10 additions & 2 deletions csrc/deepspeed4science/evoformer_attn/gemm/custom_mma_multistage.h
@@ -243,10 +243,18 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
}

CUTLASS_DEVICE
bool set_prologue_done(bool value) { prologue_done_ = value; }
bool set_prologue_done(bool value)
{
prologue_done_ = value;
return true;
}

CUTLASS_DEVICE
bool set_zero_outside_bounds(bool value) { zero_outside_bounds_ = value; }
bool set_zero_outside_bounds(bool value)
{
zero_outside_bounds_ = value;
return true;
}

template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage,
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/constants.py
@@ -57,7 +57,7 @@
UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2

# Vocabulary padding
VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
VOCAB_TENSOR = 'vocab_tensor'
PADDED_VOCAB_SIZE = 'padded_vocab_size'
ORIGINAL_VOCAB_SIZE = 'original_vocab_size'

60 changes: 41 additions & 19 deletions deepspeed/checkpoint/ds_to_universal.py
@@ -26,8 +26,7 @@
PARAM_SHAPES,
PARAM,
CAT_DIM,
VOCAB_DIVISIBILITY_PADDING_TENSOR,
ORIGINAL_VOCAB_SIZE,
VOCAB_TENSOR,
UNIVERSAL_CHECKPOINT_INFO,
VOCABULARY_PARAMETER_PATTERNS,
PIPELINE_REPLICATED_PARAMETER_PATTERNS,
@@ -55,6 +54,10 @@ def parse_arguments():
parser.add_argument('--keep_temp_folder',
action='store_true',
help='Preserve temporary folder of intermediate checkpoint slice files. Useful for debugging.')
parser.add_argument('--no_strict',
dest='strict',
action='store_false',
help='Do not perform validity checks on converted checkpoint.')
args = parser.parse_args()
print(f'args = {args}')
return args
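
The new --no_strict flag uses argparse's store_false with dest='strict', so validity checks stay on by default and the flag switches them off. A minimal standalone sketch of that wiring, separate from the real converter script (not part of the commit):

import argparse

# store_false with an explicit dest: args.strict defaults to True,
# and passing --no_strict flips it to False.
parser = argparse.ArgumentParser(description='illustrative parser only')
parser.add_argument('--no_strict',
                    dest='strict',
                    action='store_false',
                    help='Do not perform validity checks on converted checkpoint.')

print(parser.parse_args([]).strict)               # True  -> checks enabled
print(parser.parse_args(['--no_strict']).strict)  # False -> checks skipped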
@@ -149,15 +152,8 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
return slices


def _get_vocab_divisibility_padding_tensor(universal_checkpoint_info, padded_vocab_tensor):
original_vocab_size = universal_checkpoint_info.get(ORIGINAL_VOCAB_SIZE)
if padded_vocab_tensor.shape[0] > original_vocab_size:
return padded_vocab_tensor[-1]
else:
return torch.zeros(padded_vocab_tensor.shape[1])


def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):

name, shape = name_and_shape
slice_base_path = os.path.join(slice_dir, name)
param_base_path = os.path.join(dir, name)
@@ -167,39 +163,53 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, [])
parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
unmatched_patterns = set(replicated_parameters + parameters_to_average + parameters_with_row_parallelism +
vocabulary_parameters)

def get_matched_pattern(patterns_, name_):
matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)]
assert len(matched_) <= 1, f'Got more than one matching patterns={matched_} for {name_}'
if matched_:
pattern_ = matched_[0]
unmatched_patterns.discard(pattern_)
return pattern_
return None

for state in ("fp32", "exp_avg", "exp_avg_sq"):
slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
final_path = os.path.join(param_base_path, f"{state}.pt")

#print(f"Expected shape: {shape}")
#print(f"Fragment sizes:", list(frag.shape for frag in slices))
ckpt_dict = {}
if any(re.match(pattern, name) for pattern in replicated_parameters):
if get_matched_pattern(replicated_parameters, name):
if len(slices) > 1:
assert all([slices[0].equal(other_slice) for other_slice in slices[1:]])
param = slices[0]
# print(f'replicate {name} using first slice')
elif any(re.match(pattern, name) for pattern in parameters_to_average):
elif get_matched_pattern(parameters_to_average, name):
param = sum(slices) / len(slices)
# print(f'merge {name} using average')
else:
cat_dim = 1 if any(re.match(pattern, name) for pattern in parameters_with_row_parallelism) else 0
cat_dim = 1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0
# print(f"merge {name} with CAT DIM: {cat_dim}")
param = torch.cat(slices, dim=cat_dim)
ckpt_dict[CAT_DIM] = cat_dim

if any(re.match(pattern, name) for pattern in vocabulary_parameters):
if get_matched_pattern(vocabulary_parameters, name):
#print(f"Before {param.shape=}")
# strip padding
#param = _strip_vocab_padding(ds_checkpoint, param)
ckpt_dict[VOCAB_DIVISIBILITY_PADDING_TENSOR] = _get_vocab_divisibility_padding_tensor(
universal_checkpoint_info, param)
original_vocab_size = universal_checkpoint_info['original_vocab_size']
param = param[:original_vocab_size, :]
ckpt_dict[VOCAB_TENSOR] = True
#print(f"After {param.shape=}")

#print(f"Final shape: {param.shape}")
ckpt_dict[PARAM] = param
_save_checkpoint(final_path, ckpt_dict)

return unmatched_patterns


def _get_chunks(l, n):
for i in range(0, len(l), n):
@@ -208,10 +218,13 @@ def _get_chunks(l, n):

def _do_parallel_work(do_work, work_chunks, num_workers):
pool = multiprocessing.Pool(num_workers)
results = []
for batch in tqdm.tqdm(work_chunks):
pool.map(do_work, batch)
res = pool.map(do_work, batch)
results.extend(res)
pool.close()
pool.join()
return results


def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
@@ -232,7 +245,16 @@ def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
#pprint(work_chunks)
zero_output_folder = os.path.join(args.output_folder, "zero")
do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
_do_parallel_work(do_work, work_chunks, args.num_merge_workers)
unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers)

# verify that all patterns were used
# if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
sets = [set(lst) for lst in unmatched_patterns_lists]
unmatched_patterns = list(set.intersection(*sets))
if args.strict:
assert not unmatched_patterns, f'Unused patterns={unmatched_patterns} while merging tp slices'
elif unmatched_patterns:
print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')


def _save_optimizer_state(args, ds_checkpoint):
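
Taken together, the changes in this file make merge_tp_slices return the configured patterns it never matched, have _do_parallel_work collect those per-chunk results, and have _merge_tp_slice_files intersect them across workers, so a pattern is only flagged if no worker matched it anywhere. A standalone sketch of that reduction with invented pattern names (not part of the commit):

# Hypothetical per-worker results: each set holds the patterns that worker never matched.
unmatched_patterns_lists = [
    {'word_embeddings.weight', 'final_layernorm.weight'},
    {'final_layernorm.weight'},
    {'final_layernorm.weight', 'lm_head.weight'},
]

# A pattern counts as unused only if every worker failed to match it.
sets = [set(lst) for lst in unmatched_patterns_lists]
unmatched_patterns = list(set.intersection(*sets))  # ['final_layernorm.weight']

strict = False  # as if the converter were run with --no_strict
if strict:
    assert not unmatched_patterns, f'Unused patterns={unmatched_patterns} while merging tp slices'
elif unmatched_patterns:
    print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')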
15 changes: 5 additions & 10 deletions deepspeed/checkpoint/universal_checkpoint.py
@@ -6,7 +6,7 @@
import os
import torch
import types
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_DIVISIBILITY_PADDING_TENSOR, CAT_DIM)
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM)


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
@@ -43,21 +43,16 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
# the converter to universal currently strips the original padding completely so the saved
# weight is padding-free and we just need to add new padding depending on the target TP
# degree
vocab_divisibility_padding_tensor = ckpt_dict.get(VOCAB_DIVISIBILITY_PADDING_TENSOR, None)
if vocab_divisibility_padding_tensor is not None:
is_vocab_tensor = ckpt_dict.get(VOCAB_TENSOR, False)
if is_vocab_tensor:
# In the absence of data passed from the user wrt new padded vocab specific to tp degree
# we can again derive that data by reverse engineering the target shapes like so:
padded_target_vocab_size = self.shape[0] * tp_world_size
assert padded_target_vocab_size >= full_hp_param.shape[0], \
f'Vocab tensor padded size {padded_target_vocab_size} < loaded universal size {full_hp_param.shape[0]}'
if padded_target_vocab_size > full_hp_param.shape[0]:
# Need to expand
padding_size = padded_target_vocab_size - full_hp_param.shape[0]
# Implement the following concat in efficient way using pad
#full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)
full_hp_param[:-padding_size, :] = vocab_divisibility_padding_tensor
else:
# Need to shrink or keep the same
full_hp_param = full_hp_param[:padded_target_vocab_size, :]

full_param_numel = full_hp_param.numel()
tp_slice_numel = self.numel()
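
Because the converter now strips vocabulary padding entirely (storing only the VOCAB_TENSOR flag), the loader re-pads the weight with zero rows up to shape[0] * tp_world_size, or slices it down, to match the target tensor-parallel degree. A self-contained sketch of that re-padding step with made-up sizes (not part of the commit):

import torch
import torch.nn.functional as F

# Assumed example sizes: the universal checkpoint holds the un-padded vocab weight,
# and the target run splits the embedding across tp_world_size ranks.
full_hp_param = torch.randn(50257, 1024)  # padding-free vocab tensor from the checkpoint
per_rank_rows = 12576                     # self.shape[0]: rows of this parameter on one TP rank
tp_world_size = 4

padded_target_vocab_size = per_rank_rows * tp_world_size  # 50304
assert padded_target_vocab_size >= full_hp_param.shape[0]

if padded_target_vocab_size > full_hp_param.shape[0]:
    # Grow: append zero rows so the row count divides evenly across TP ranks.
    padding_size = padded_target_vocab_size - full_hp_param.shape[0]
    full_hp_param = F.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)
else:
    # Shrink or keep: drop any excess rows.
    full_hp_param = full_hp_param[:padded_target_vocab_size, :]

print(full_hp_param.shape)  # torch.Size([50304, 1024])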
28 changes: 16 additions & 12 deletions deepspeed/ops/deepspeed4science/evoformer_attn.py
@@ -30,7 +30,7 @@ def _attention(Q, K, V, bias1, bias2):
return O, lse


def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2):
def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2, bias1_grad, bias2_grad):
assert max(Q.shape[-1], V.shape[-1]) <= 64, "Hidden size is too large. Need to change kMax to a larger value"
dQ = torch.empty_like(Q, dtype=Q.dtype)
dK = torch.empty_like(K, dtype=K.dtype)
@@ -44,8 +44,14 @@ def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2):
if kernel_ is None:
kernel_ = EvoformerAttnBuilder().load()
delta = torch.empty_like(lse)
dB1 = torch.zeros_like(bias1, dtype=torch.float32)
dB2 = torch.zeros_like(bias2, dtype=torch.float32)
if bias1_grad:
dB1 = torch.zeros_like(bias1, dtype=torch.float32)
else:
dB1 = torch.tensor([], dtype=torch.float32, device=bias1.device)
if bias2_grad:
dB2 = torch.zeros_like(bias2, dtype=torch.float32)
else:
dB2 = torch.tensor([], dtype=torch.float32, device=bias2.device)
kernel_.attention_bwd(dO, Q, K, V, O, lse, delta, bias1, bias2, dQ, dK, dV, dB1, dB2)
return dQ, dK, dV, dB1.to(dO.dtype), dB2.to(dO.dtype)

@@ -69,10 +75,12 @@ def forward(ctx, q, k, v, bias1=None, bias2=None):
@staticmethod
def backward(ctx, grad_output):
q, k, v, o, lse, bias1, bias2 = ctx.saved_tensors
dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2)
if bias1.numel() == 0:
is_b1_grad = bias1.numel() != 0 and ctx.needs_input_grad[3]
is_b2_grad = bias2.numel() != 0 and ctx.needs_input_grad[4]
dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2, is_b1_grad, is_b2_grad)
if not is_b1_grad:
dB1 = None
if bias2.numel() == 0:
if not is_b2_grad:
dB2 = None
return dQ, dK, dV, dB1, dB2

@@ -90,13 +98,9 @@ def DS4Sci_EvoformerAttention(Q, K, V, biases):
bias_2_shape = lambda x: (x.shape[0], 1, x.shape[3], x.shape[2], x.shape[2])

if biases[0] is not None:
assert biases[0].shape == bias_1_shape(Q)
else:
biases[0] = Q.new_zeros(bias_1_shape(Q))
assert biases[0].shape == bias_1_shape(Q), "bias1 shape is incorrect"

if biases[1] is not None:
assert biases[1].shape == bias_2_shape(Q)
else:
biases[1] = Q.new_zeros(bias_2_shape(Q))
assert biases[1].shape == bias_2_shape(Q), "bias2 shape is incorrect"

return EvoformerFusedAttention.apply(Q, K, V, biases[0], biases[1])
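
Two behavioral changes land here: bias gradient buffers are allocated only when the corresponding bias is non-empty and autograd asked for its gradient (ctx.needs_input_grad[3]/[4]), and DS4Sci_EvoformerAttention no longer substitutes zero biases, so the shape asserts apply only to biases the caller actually provides. A rough sketch of the expected tensor shapes, using invented dimension names and sizes and without calling the CUDA kernel (not part of the commit):

import torch

# Invented sizes; my reading of the axes is (batch, n, seq, heads, dim).
batch, n, seq, heads, dim = 1, 4, 64, 8, 32

Q = torch.randn(batch, n, seq, heads, dim)

# Shapes the asserts above require, when the corresponding bias is supplied:
bias_1_shape = (Q.shape[0], Q.shape[1], 1, 1, Q.shape[2])           # (batch, n, 1, 1, seq)
bias_2_shape = (Q.shape[0], 1, Q.shape[3], Q.shape[2], Q.shape[2])  # (batch, 1, heads, seq, seq)

bias1 = torch.randn(*bias_1_shape, requires_grad=True)
bias2 = torch.randn(*bias_2_shape, requires_grad=True)

# Backward only materializes dB1/dB2 when the bias exists and needs a gradient;
# otherwise an empty tensor goes to the kernel and None is returned to autograd.
is_b1_grad = bias1.numel() != 0 and bias1.requires_grad
is_b2_grad = bias2.numel() != 0 and bias2.requires_grad
print(bias1.shape, bias2.shape, is_b1_grad, is_b2_grad)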
8 changes: 8 additions & 0 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -237,6 +237,14 @@ def model_parallel_cuda_manual_seed(seed):
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)


def model_parallel_reconfigure_tp_seed(seed):
global mpu
tp_rank = bwc_tensor_model_parallel_rank(mpu)
model_parallel_seed = seed + 2718 + tp_rank
with _CUDA_RNG_STATE_TRACKER.fork():
get_accelerator().manual_seed(model_parallel_seed)


def get_partition_start(item):
global mp_rank, mp_size, mp_group
size = item.numel()
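
model_parallel_reconfigure_tp_seed re-derives the model-parallel seed as seed + 2718 + tp_rank and re-seeds the accelerator inside the RNG tracker's fork, so each tensor-parallel rank keeps a distinct but reproducible stream after the TP degree changes. A plain-Python illustration of the per-rank seed arithmetic, with example rank values (not part of the commit):

def derived_model_parallel_seed(seed: int, tp_rank: int) -> int:
    # Same arithmetic as above: a fixed offset separates the model-parallel stream
    # from the default stream, and tp_rank decorrelates the TP ranks.
    return seed + 2718 + tp_rank

base_seed = 1234
for tp_rank in range(4):  # example: a 4-way tensor-parallel group
    print(tp_rank, derived_model_parallel_seed(base_seed, tp_rank))
# 0 3952
# 1 3953
# 2 3954
# 3 3955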
2 changes: 0 additions & 2 deletions deepspeed/runtime/bf16_optimizer.py
@@ -75,7 +75,6 @@ def __init__(self,
self.fp32_groups_gradient_flat_partition = []
self.fp32_groups_has_gradients = []

self.step_count = 0
self.group_paddings = []

if self.using_real_optimizer:
@@ -252,7 +251,6 @@ def step(self, closure=None):
self.update_lp_params()

self.clear_hp_grads()
self.step_count += 1

def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
"""Perform a backward pass and copy the low-precision gradients to the
20 changes: 20 additions & 0 deletions deepspeed/runtime/engine.py
@@ -2726,6 +2726,8 @@ def load_checkpoint(self,

if self.load_universal_checkpoint():
self.optimizer.update_lp_params()
if load_zero_checkpoint:
self.update_optimizer_step(step=client_states['iteration'] + 1)

return load_path, client_states

@@ -2903,6 +2905,24 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}")
return True

def update_optimizer_step(self, step):

def set_step(d):
if isinstance(d['step'], torch.Tensor):
d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
else:
d['step'] = step

optimizer = self.optimizer
base_optimizer = optimizer.optimizer
state = base_optimizer.state
for group in optimizer.param_groups:
if 'step' in group:
set_step(group)
for p in group['params']:
if p in state and len(state[p]) > 0 and 'step' in state[p]:
set_step(state[p])

def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
zero_ckpt_names = []
for dp_rank in range(dp_world_size):
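
When a universal checkpoint is loaded together with a ZeRO checkpoint, the engine now rewrites the optimizer's step counters to iteration + 1, both in the param groups and in the per-parameter state, preserving whichever representation (tensor or plain int) the optimizer already uses. A standalone sketch of that rewrite against a vanilla torch.optim.Adam, with an invented iteration value (not part of the commit):

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())

# Take one real step so Adam populates per-parameter state, including 'step'.
model(torch.randn(2, 4)).sum().backward()
optimizer.step()

def set_step(d, step):
    # Keep the existing representation: recent Adam versions store 'step' as a tensor,
    # older ones (and some param groups) store a plain int.
    if isinstance(d['step'], torch.Tensor):
        d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
    else:
        d['step'] = step

restored_iteration = 1000  # e.g. client_states['iteration'] from the loaded checkpoint
for group in optimizer.param_groups:
    if 'step' in group:
        set_step(group, restored_iteration + 1)
    for p in group['params']:
        if p in optimizer.state and 'step' in optimizer.state[p]:
            set_step(optimizer.state[p], restored_iteration + 1)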
3 changes: 0 additions & 3 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -107,7 +107,6 @@ def __init__(self,

self.clip_grad = clip_grad
self.norm_type = 2
self.step_count = 0

if required_torch_version(max_version=0.4):
self.clip_grad_norm = torch.nn.utils.clip_grad_norm
@@ -289,8 +288,6 @@ def step(self, closure=None):

self.timers.log(STEP_TIMERS)

self.step_count += 1

return self.overflow

def _get_norm_with_moe_layers(self, all_groups_norm):