From 7914b195c0a41ed01823ceac469b58407d9de021 Mon Sep 17 00:00:00 2001
From: Michael Wyatt
Date: Thu, 14 Dec 2023 16:12:46 -0800
Subject: [PATCH 1/3] fix for tests using torch<2.1 (#4818)

Our torch 1.10 tests have been failing since the merge of #4569, which
added a `device_type` kwarg to the `torch.random.fork_rng` call. That
kwarg was only introduced to torch in
https://github.com/pytorch/pytorch/pull/98069 and is therefore not
available in older versions of torch.

Fixes #4644, #4503
---
 tests/unit/alexnet_model.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py
index e3be2be4894d..cf533063d6ec 100644
--- a/tests/unit/alexnet_model.py
+++ b/tests/unit/alexnet_model.py
@@ -11,6 +11,7 @@
 import deepspeed
 import deepspeed.comm as dist
 import deepspeed.runtime.utils as ds_utils
+from deepspeed.runtime.utils import required_torch_version
 from deepspeed.accelerator import get_accelerator
 from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec
 
@@ -111,8 +112,11 @@ def cifar_trainset(fp16=False):
 
 
 def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
-    with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()],
-                                             device_type=get_accelerator().device_name()):
+    if required_torch_version(min_version=2.1):
+        fork_kwargs = {"device_type": get_accelerator().device_name()}
+    else:
+        fork_kwargs = {}
+    with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()], **fork_kwargs):
         ds_utils.set_random_seed(seed)
 
         # disable dropout

From 8998707a2fc8584712a4cb3dc465d02e7d9f50da Mon Sep 17 00:00:00 2001
From: Sam Ade Jacobs
Date: Fri, 15 Dec 2023 13:22:39 -0500
Subject: [PATCH 2/3] Universal Checkpoint for Sequence Parallelism (#4752)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR extends the [universal checkpoint](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing)
to support DS sequence parallelism and training scenarios where pipeline
parallelism is not enabled.

The attached Tensorboard chart shows a training scenario (validation curve)
in which a GPT model is pre-trained with data parallelism (4 GPUs) and
checkpoints are saved at the 100th and 200th iterations. The checkpoint from
the 100th iteration is later loaded for continual pre-training with a
different configuration (more GPU resources: data parallelism = 4 GPUs,
sequence parallelism = 2 GPUs).
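For context, that resumption flow first converts the saved ZeRO checkpoint
into the universal format offline and then launches the new run against the
converted folder. A minimal sketch of the conversion step, not part of this
patch: the checkpoint paths are illustrative, and the flag names follow the
linked universal_checkpointing example, so they may differ per setup:

    # convert the iteration-100 ZeRO checkpoint into a universal checkpoint
    python deepspeed/checkpoint/ds_to_universal.py \
        --input_folder  checkpoints/gpt/global_step100 \
        --output_folder checkpoints/gpt/global_step100_universal

The resumed run with the new data/sequence-parallel layout then points its
load path at the converted folder (the linked example does this with
Megatron-DeepSpeed's `--universal-checkpoint` flag).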
[Screenshot 2023-11-28: Tensorboard validation curve]

---------

Co-authored-by: Michael Wyatt
---
 deepspeed/checkpoint/deepspeed_checkpoint.py | 36 +++++++++++++-------
 deepspeed/checkpoint/ds_to_universal.py      |  8 ++---
 deepspeed/checkpoint/reshape_3d_utils.py     |  2 +-
 deepspeed/checkpoint/universal_checkpoint.py |  2 +-
 deepspeed/runtime/zero/stage_1_and_2.py      |  4 ++-
 5 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py
index 77634222d292..8312dddd2fa6 100644
--- a/deepspeed/checkpoint/deepspeed_checkpoint.py
+++ b/deepspeed/checkpoint/deepspeed_checkpoint.py
@@ -34,7 +34,10 @@ class DeepSpeedCheckpoint(object):
 
     def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
         self.dir = dir
-        self._validate_folder(dir)
+
+        pipeline_parallel = len(get_files_with_prefix(get_files(dir), LAYER_FILE_PREFIX)) > 0
+
+        self._validate_folder(dir, pipeline_parallel)
 
         self.zero_checkpoint = ZeROCheckpoint(dir)
 
@@ -193,7 +196,10 @@ def get_final_norm_files(self, tp_index: int) -> list:
         return self.tp_to_final_norm_map[tp_index]
 
     def _build_tp_other_layer_map(self, layer_index: int):
-        assert layer_index < len(self.layer_files)
+        data_map = {}
+        if len(self.layer_files) < 1:
+            return data_map
+        assert layer_index <= len(self.layer_files)
         layer_files = get_files_with_prefix(self.layer_files, self.layer_keys[layer_index])
         layer_file_partitions = partition_data(layer_files, self.tp_degree)
         data_map = {i: flist for i, flist in enumerate(layer_file_partitions)}
@@ -207,9 +213,13 @@ def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list:
 
     def _build_pp_transformer_map(self):
         data_map = {}
-        transformer_layers = self.layer_keys[1:-1]
-        layers_per_pp = len(transformer_layers) // self.pp_degree
-        data_map = {i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] for i in range(0, self.pp_degree)}
+        if self.pp_degree > 0:
+            transformer_layers = self.layer_keys[1:-1]
+            layers_per_pp = len(transformer_layers) // self.pp_degree
+            data_map = {
+                i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
+                for i in range(0, self.pp_degree)
+            }
         return data_map
 
     def _dump_mapping(self, data_map, map_tag=None):
@@ -222,9 +232,9 @@ def _build_transformer_file_map(self):
         transformer_layer_keys = self.layer_keys[1:-1]
         file_map = {}
         # XXX: this is not guaranteed
-        layers_per_pp = len(transformer_layer_keys) // self.pp_degree
-        if layers_per_pp == 0:
-            layers_per_pp = 1
+        layers_per_pp = 1
+        if self.pp_degree > 0:
+            layers_per_pp = len(transformer_layer_keys) // self.pp_degree
         #print(f"{transformer_layer_keys} {layers_per_pp}")
         for key_index, layer_key in enumerate(transformer_layer_keys):
             pp_index = key_index // layers_per_pp
@@ -240,8 +250,8 @@ def _sanity_check(self):
         assert len(self.mp_rank_files) % self.tp_degree == 0
-        assert len(self.layer_keys) > 2
         assert self.zero_checkpoint.num_files % (self.pp_degree * self.tp_degree) == 0
+        assert self.zero_checkpoint.num_files % (self.tp_degree) == 0
         # XXX: fix me - isn't always the case
         # only true with --pp-partition-method 'type:transformer|embedding' \
         # assert (len(self.layer_keys) - 2) % self.pp_degree == 0
@@ -270,12 +280,14 @@ def _merge_state_dicts(self, sd_list):
 
         return merged_sd
 
-    def _validate_folder(self, dir):
+    def _validate_folder(self, dir, pipeline_parallel):
         basic_folder_validation(dir)
 
         file_list = get_files(dir)
-
-        for file_prefix in [MODEL_FILE_PREFIX, LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01']:
+        file_prefix_list = [MODEL_FILE_PREFIX]
+        if pipeline_parallel:
+            file_prefix_list.extend([LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01'])
+        for file_prefix in file_prefix_list:
             ckpt_files = get_files_with_prefix(file_list, file_prefix)
             assert len(
                 ckpt_files
diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py
index 8be187aa89c2..f40c5630899d 100755
--- a/deepspeed/checkpoint/ds_to_universal.py
+++ b/deepspeed/checkpoint/ds_to_universal.py
@@ -15,7 +15,7 @@
 import shutil
 import torch
 import tqdm
-# from pprint import pprint
+#from pprint import pprint
 
 from deepspeed.checkpoint import DeepSpeedCheckpoint
 from deepspeed.checkpoint import (
@@ -241,9 +241,9 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
     _3d_range_list = list(
         itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
                           range(ds_checkpoint.dp_degree)))
-    # pprint(f'{_3d_range_list=}')
+    #pprint(f'{_3d_range_list=}')
     work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
-    # pprint(f'{work_chunks=}')
+    #pprint(f'{work_chunks=}')
 
     # extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
     do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
@@ -309,7 +309,7 @@ def main():
     print('*** 1. Extracting ZeRO fragments')
     _extract_zero_shard_files(args, ds_checkpoint, temp_dir)
 
-    print('*** 2. Merging slices')
+    print('*** 2. Merging slices .....')
     _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)
 
     print('*** 3. Saving common optimizer states')
diff --git a/deepspeed/checkpoint/reshape_3d_utils.py b/deepspeed/checkpoint/reshape_3d_utils.py
index b5bf41e2d160..02b3947624a1 100644
--- a/deepspeed/checkpoint/reshape_3d_utils.py
+++ b/deepspeed/checkpoint/reshape_3d_utils.py
@@ -81,7 +81,7 @@ def get_model_3d_descriptor(dir):
     else:
         tp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX))
         dp_degree = max(1, len(zero_file_list) // tp_degree)
-        pp_degree = 0
+        pp_degree = 1
 
     return model_3d_desc(pp_degree, tp_degree, dp_degree)
 
diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py
index 5849a834cdd3..542d1125c566 100644
--- a/deepspeed/checkpoint/universal_checkpoint.py
+++ b/deepspeed/checkpoint/universal_checkpoint.py
@@ -13,8 +13,8 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
     hp_mapping = self._hp_mapping
     optim_state_keys = hp_mapping.get_optim_state_keys()
     hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys
+    #print(f'{hp_keys=}')
     checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys}
-
     for file in checkpoint_files.values():
         assert os.path.isfile(file), f'{file} is not a valid file'
 
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 1d2d561dbd39..aeb533698af3 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -2258,7 +2258,9 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
         self._load_global_state(optim_sd)
 
         tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
-        tp_world_size = self.mpu.get_slice_parallel_world_size()
+        tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
+            else self.mpu.get_tensor_model_parallel_world_size()
+
         for i, _ in enumerate(self.optimizer.param_groups):
             for lp in self.bit16_groups[i]:
                 if lp._hp_mapping is not None:

From 84eaf5ac843234737f0b49e36a818d1aabd1776f Mon Sep 17 00:00:00 2001
From: Michael Wyatt
Date: Fri, 15 Dec 2023 10:23:22 -0800
Subject: [PATCH 3/3] Accelerate CI fix (#4819)

---
 .github/workflows/nv-accelerate-v100.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/nv-accelerate-v100.yml b/.github/workflows/nv-accelerate-v100.yml
index d8a03ff34f78..31e413124aed 100644
--- a/.github/workflows/nv-accelerate-v100.yml
+++ b/.github/workflows/nv-accelerate-v100.yml
@@ -18,7 +18,7 @@ concurrency:
 
 jobs:
   unit-tests:
-    runs-on: [self-hosted, nvidia, cu111, v100]
+    runs-on: [self-hosted, nvidia, cu116, v100]
 
     steps:
       - uses: actions/checkout@v3
@@ -28,7 +28,7 @@ jobs:
 
       - name: Install pytorch
         run: |
-          pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/cu111
+          pip install -U --cache-dir $TORCH_CACHE torch --index-url https://download.pytorch.org/whl/cu118
           python -c "import torch; print('torch:', torch.__version__, torch)"
           python -c "import torch; print('CUDA available:', torch.cuda.is_available())"