Skip to content

Commit

Permalink
Merge branch 'master' into loadams/update-transformers-deepspeed-modules
Browse files: browse the repository at this point in the history
  • Loading branch information
loadams committed Dec 15, 2023
2 parents c1488f5 + 84eaf5a commit fc2eba7
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 23 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/nv-accelerate-v100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ concurrency:

jobs:
unit-tests:
- runs-on: [self-hosted, nvidia, cu111, v100]
+ runs-on: [self-hosted, nvidia, cu116, v100]

steps:
- uses: actions/checkout@v3
Expand All @@ -28,7 +28,7 @@ jobs:

- name: Install pytorch
run: |
- pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/cu111
+ pip install -U --cache-dir $TORCH_CACHE torch --index-url https://download.pytorch.org/whl/cu118
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand Down
36 changes: 24 additions & 12 deletions deepspeed/checkpoint/deepspeed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ class DeepSpeedCheckpoint(object):

def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
self.dir = dir
- self._validate_folder(dir)
+
+ pipeline_parallel = len(get_files_with_prefix(get_files(dir), LAYER_FILE_PREFIX)) > 0
+
+ self._validate_folder(dir, pipeline_parallel)

self.zero_checkpoint = ZeROCheckpoint(dir)

Expand Down Expand Up @@ -193,7 +196,10 @@ def get_final_norm_files(self, tp_index: int) -> list:
return self.tp_to_final_norm_map[tp_index]

def _build_tp_other_layer_map(self, layer_index: int):
- assert layer_index < len(self.layer_files)
+ data_map = {}
+ if len(self.layer_files) < 1:
+ return data_map
+ assert layer_index <= len(self.layer_files)
layer_files = get_files_with_prefix(self.layer_files, self.layer_keys[layer_index])
layer_file_partitions = partition_data(layer_files, self.tp_degree)
data_map = {i: flist for i, flist in enumerate(layer_file_partitions)}
Expand All @@ -207,9 +213,13 @@ def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list:

def _build_pp_transformer_map(self):
data_map = {}
- transformer_layers = self.layer_keys[1:-1]
- layers_per_pp = len(transformer_layers) // self.pp_degree
- data_map = {i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] for i in range(0, self.pp_degree)}
+ if self.pp_degree > 0:
+ transformer_layers = self.layer_keys[1:-1]
+ layers_per_pp = len(transformer_layers) // self.pp_degree
+ data_map = {
+ i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
+ for i in range(0, self.pp_degree)
+ }
return data_map

def _dump_mapping(self, data_map, map_tag=None):
Expand All @@ -222,9 +232,9 @@ def _build_transformer_file_map(self):
transformer_layer_keys = self.layer_keys[1:-1]
file_map = {}
# XXX: this is not guaranteed
- layers_per_pp = len(transformer_layer_keys) // self.pp_degree
- if layers_per_pp == 0:
- layers_per_pp = 1
+ layers_per_pp = 1
+ if self.pp_degree > 0:
+ layers_per_pp = len(transformer_layer_keys) // self.pp_degree
#print(f"{transformer_layer_keys} {layers_per_pp}")
for key_index, layer_key in enumerate(transformer_layer_keys):
pp_index = key_index // layers_per_pp
Expand All @@ -240,8 +250,8 @@ def _build_transformer_file_map(self):

def _sanity_check(self):
assert len(self.mp_rank_files) % self.tp_degree == 0
assert len(self.layer_keys) > 2
- assert self.zero_checkpoint.num_files % (self.pp_degree * self.tp_degree) == 0
+ assert self.zero_checkpoint.num_files % (self.tp_degree) == 0
# XXX: fix me - isn't always the case
# only true with --pp-partition-method 'type:transformer|embedding' \
# assert (len(self.layer_keys) - 2) % self.pp_degree == 0
Expand Down Expand Up @@ -270,12 +280,14 @@ def _merge_state_dicts(self, sd_list):

return merged_sd

- def _validate_folder(self, dir):
+ def _validate_folder(self, dir, pipeline_parallel):
basic_folder_validation(dir)

file_list = get_files(dir)

- for file_prefix in [MODEL_FILE_PREFIX, LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01']:
+ file_prefix_list = [MODEL_FILE_PREFIX]
+ if pipeline_parallel:
+ file_prefix_list.extend([LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01'])
+ for file_prefix in file_prefix_list:
ckpt_files = get_files_with_prefix(file_list, file_prefix)
assert len(
ckpt_files
Expand Down
8 changes: 4 additions & 4 deletions deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import shutil
import torch
import tqdm
- # from pprint import pprint
+ #from pprint import pprint

from deepspeed.checkpoint import DeepSpeedCheckpoint
from deepspeed.checkpoint import (
Expand Down Expand Up @@ -241,9 +241,9 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
_3d_range_list = list(
itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
range(ds_checkpoint.dp_degree)))
- # pprint(f'{_3d_range_list=}')
+ #pprint(f'{_3d_range_list=}')
  work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
- # pprint(f'{work_chunks=}')
+ #pprint(f'{work_chunks=}')

# extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
Expand Down Expand Up @@ -309,7 +309,7 @@ def main():
print('*** 1. Extracting ZeRO fragments')
_extract_zero_shard_files(args, ds_checkpoint, temp_dir)

- print('*** 2. Merging slices')
+ print('*** 2. Merging slices .....')
_merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)

print('*** 3. Saving common optimizer states')
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/reshape_3d_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_model_3d_descriptor(dir):
else:
tp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX))
dp_degree = max(1, len(zero_file_list) // tp_degree)
- pp_degree = 0
+ pp_degree = 1

return model_3d_desc(pp_degree, tp_degree, dp_degree)

Expand Down
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/universal_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
hp_mapping = self._hp_mapping
optim_state_keys = hp_mapping.get_optim_state_keys()
hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys
- #print(f'{hp_keys=}')
  checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys}
+
for file in checkpoint_files.values():
assert os.path.isfile(file), f'{file} is not a valid file'

Expand Down
4 changes: 3 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2258,7 +2258,9 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
self._load_global_state(optim_sd)

tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
- tp_world_size = self.mpu.get_slice_parallel_world_size()
+ tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
+ else self.mpu.get_tensor_model_parallel_world_size()
+
for i, _ in enumerate(self.optimizer.param_groups):
for lp in self.bit16_groups[i]:
if lp._hp_mapping is not None:
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/alexnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import deepspeed
import deepspeed.comm as dist
import deepspeed.runtime.utils as ds_utils
+ from deepspeed.runtime.utils import required_torch_version
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec

Expand Down Expand Up @@ -111,8 +112,11 @@ def cifar_trainset(fp16=False):


def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
- with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()],
- device_type=get_accelerator().device_name()):
+ if required_torch_version(min_version=2.1):
+ fork_kwargs = {"device_type": get_accelerator().device_name()}
+ else:
+ fork_kwargs = {}
+ with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()], **fork_kwargs):
ds_utils.set_random_seed(seed)

# disable dropout
Expand Down

0 comments on commit fc2eba7

Please sign in to comment.