From 7914b195c0a41ed01823ceac469b58407d9de021 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Thu, 14 Dec 2023 16:12:46 -0800 Subject: [PATCH 01/18] fix for tests using torch<2.1 (#4818) Our torch 1.10 tests have been failling since the merge of #4569. This added a `device_type` kwarg to the `torch.random.fork_rng` call. But this is not compatible with older versions of torch. Added in https://github.com/pytorch/pytorch/pull/98069 Fixes #4644, #4503 --- tests/unit/alexnet_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py index e3be2be4894d..cf533063d6ec 100644 --- a/tests/unit/alexnet_model.py +++ b/tests/unit/alexnet_model.py @@ -11,6 +11,7 @@ import deepspeed import deepspeed.comm as dist import deepspeed.runtime.utils as ds_utils +from deepspeed.runtime.utils import required_torch_version from deepspeed.accelerator import get_accelerator from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec @@ -111,8 +112,11 @@ def cifar_trainset(fp16=False): def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123): - with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()], - device_type=get_accelerator().device_name()): + if required_torch_version(min_version=2.1): + fork_kwargs = {"device_type": get_accelerator().device_name()} + else: + fork_kwargs = {} + with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()], **fork_kwargs): ds_utils.set_random_seed(seed) # disable dropout From 8998707a2fc8584712a4cb3dc465d02e7d9f50da Mon Sep 17 00:00:00 2001 From: Sam Ade Jacobs Date: Fri, 15 Dec 2023 13:22:39 -0500 Subject: [PATCH 02/18] Universal Checkpoint for Sequence Parallelism (#4752) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR extends the [universal checkpoint](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing) to support DS sequence parallelism and training scenarios where pipeline parallelism is not enabled. The attached Tensorboard chart show a training scenario (validation curve) where a GPT model is pre-trained with data parallelism (4 GPUs), and checkpoints are saved at the 100th and 200th iterations. The checkpoint at the 100th iteration is later loaded for continual pre-training with different configurations (more GPU resources, data parallelism = 4 GPUs, sequence parallelism = 2 GPUs). Screenshot 2023-11-28 at 9 11 55 AM --------- Co-authored-by: Michael Wyatt --- deepspeed/checkpoint/deepspeed_checkpoint.py | 36 +++++++++++++------- deepspeed/checkpoint/ds_to_universal.py | 8 ++--- deepspeed/checkpoint/reshape_3d_utils.py | 2 +- deepspeed/checkpoint/universal_checkpoint.py | 2 +- deepspeed/runtime/zero/stage_1_and_2.py | 4 ++- 5 files changed, 33 insertions(+), 19 deletions(-) diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py index 77634222d292..8312dddd2fa6 100644 --- a/deepspeed/checkpoint/deepspeed_checkpoint.py +++ b/deepspeed/checkpoint/deepspeed_checkpoint.py @@ -34,7 +34,10 @@ class DeepSpeedCheckpoint(object): def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None): self.dir = dir - self._validate_folder(dir) + + pipeline_parallel = len(get_files_with_prefix(get_files(dir), LAYER_FILE_PREFIX)) > 0 + + self._validate_folder(dir, pipeline_parallel) self.zero_checkpoint = ZeROCheckpoint(dir) @@ -193,7 +196,10 @@ def get_final_norm_files(self, tp_index: int) -> list: return self.tp_to_final_norm_map[tp_index] def _build_tp_other_layer_map(self, layer_index: int): - assert layer_index < len(self.layer_files) + data_map = {} + if len(self.layer_files) < 1: + return data_map + assert layer_index <= len(self.layer_files) layer_files = get_files_with_prefix(self.layer_files, self.layer_keys[layer_index]) layer_file_partitions = partition_data(layer_files, self.tp_degree) data_map = {i: flist for i, flist in enumerate(layer_file_partitions)} @@ -207,9 +213,13 @@ def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list: def _build_pp_transformer_map(self): data_map = {} - transformer_layers = self.layer_keys[1:-1] - layers_per_pp = len(transformer_layers) // self.pp_degree - data_map = {i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] for i in range(0, self.pp_degree)} + if self.pp_degree > 0: + transformer_layers = self.layer_keys[1:-1] + layers_per_pp = len(transformer_layers) // self.pp_degree + data_map = { + i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] + for i in range(0, self.pp_degree) + } return data_map def _dump_mapping(self, data_map, map_tag=None): @@ -222,9 +232,9 @@ def _build_transformer_file_map(self): transformer_layer_keys = self.layer_keys[1:-1] file_map = {} # XXX: this is not guaranteed - layers_per_pp = len(transformer_layer_keys) // self.pp_degree - if layers_per_pp == 0: - layers_per_pp = 1 + layers_per_pp = 1 + if self.pp_degree > 0: + layers_per_pp = len(transformer_layer_keys) // self.pp_degree #print(f"{transformer_layer_keys} {layers_per_pp}") for key_index, layer_key in enumerate(transformer_layer_keys): pp_index = key_index // layers_per_pp @@ -240,8 +250,8 @@ def _build_transformer_file_map(self): def _sanity_check(self): assert len(self.mp_rank_files) % self.tp_degree == 0 - assert len(self.layer_keys) > 2 assert self.zero_checkpoint.num_files % (self.pp_degree * self.tp_degree) == 0 + assert self.zero_checkpoint.num_files % (self.tp_degree) == 0 # XXX: fix me - isn't always the case # only true with --pp-partition-method 'type:transformer|embedding' \ # assert (len(self.layer_keys) - 2) % self.pp_degree == 0 @@ -270,12 +280,14 @@ def _merge_state_dicts(self, sd_list): return merged_sd - def _validate_folder(self, dir): + def _validate_folder(self, dir, pipeline_parallel): basic_folder_validation(dir) file_list = get_files(dir) - - for file_prefix in [MODEL_FILE_PREFIX, LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01']: + file_prefix_list = [MODEL_FILE_PREFIX] + if pipeline_parallel: + file_prefix_list.extend([LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01']) + for file_prefix in file_prefix_list: ckpt_files = get_files_with_prefix(file_list, file_prefix) assert len( ckpt_files diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index 8be187aa89c2..f40c5630899d 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -15,7 +15,7 @@ import shutil import torch import tqdm -# from pprint import pprint +#from pprint import pprint from deepspeed.checkpoint import DeepSpeedCheckpoint from deepspeed.checkpoint import ( @@ -241,9 +241,9 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir): _3d_range_list = list( itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree), range(ds_checkpoint.dp_degree))) - # pprint(f'{_3d_range_list=}') + #pprint(f'{_3d_range_list=}') work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers)) - # pprint(f'{work_chunks=}') + #pprint(f'{work_chunks=}') # extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0]) do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint) @@ -309,7 +309,7 @@ def main(): print('*** 1. Extracting ZeRO fragments') _extract_zero_shard_files(args, ds_checkpoint, temp_dir) - print('*** 2. Merging slices') + print('*** 2. Merging slices .....') _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir) print('*** 3. Saving common optimizer states') diff --git a/deepspeed/checkpoint/reshape_3d_utils.py b/deepspeed/checkpoint/reshape_3d_utils.py index b5bf41e2d160..02b3947624a1 100644 --- a/deepspeed/checkpoint/reshape_3d_utils.py +++ b/deepspeed/checkpoint/reshape_3d_utils.py @@ -81,7 +81,7 @@ def get_model_3d_descriptor(dir): else: tp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX)) dp_degree = max(1, len(zero_file_list) // tp_degree) - pp_degree = 0 + pp_degree = 1 return model_3d_desc(pp_degree, tp_degree, dp_degree) diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py index 5849a834cdd3..542d1125c566 100644 --- a/deepspeed/checkpoint/universal_checkpoint.py +++ b/deepspeed/checkpoint/universal_checkpoint.py @@ -13,8 +13,8 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): hp_mapping = self._hp_mapping optim_state_keys = hp_mapping.get_optim_state_keys() hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys + #print(f'{hp_keys=}') checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys} - for file in checkpoint_files.values(): assert os.path.isfile(file), f'{file} is not a valid file' diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 1d2d561dbd39..aeb533698af3 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -2258,7 +2258,9 @@ def _load_hp_checkpoint_state(self, checkpoint_dir): self._load_global_state(optim_sd) tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) - tp_world_size = self.mpu.get_slice_parallel_world_size() + tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \ + else self.mpu.get_tensor_model_parallel_world_size() + for i, _ in enumerate(self.optimizer.param_groups): for lp in self.bit16_groups[i]: if lp._hp_mapping is not None: From 84eaf5ac843234737f0b49e36a818d1aabd1776f Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Fri, 15 Dec 2023 10:23:22 -0800 Subject: [PATCH 03/18] Accelerate CI fix (#4819) --- .github/workflows/nv-accelerate-v100.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/nv-accelerate-v100.yml b/.github/workflows/nv-accelerate-v100.yml index d8a03ff34f78..31e413124aed 100644 --- a/.github/workflows/nv-accelerate-v100.yml +++ b/.github/workflows/nv-accelerate-v100.yml @@ -18,7 +18,7 @@ concurrency: jobs: unit-tests: - runs-on: [self-hosted, nvidia, cu111, v100] + runs-on: [self-hosted, nvidia, cu116, v100] steps: - uses: actions/checkout@v3 @@ -28,7 +28,7 @@ jobs: - name: Install pytorch run: | - pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/cu111 + pip install -U --cache-dir $TORCH_CACHE torch --index-url https://download.pytorch.org/whl/cu118 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" From 4a6e0c06240b45185709ac4a2902ec42518049d2 Mon Sep 17 00:00:00 2001 From: jxysoft Date: Sat, 16 Dec 2023 05:01:12 +0800 Subject: [PATCH 04/18] =?UTF-8?q?fix=20[BUG]=20'DeepSpeedGPTInference'=20o?= =?UTF-8?q?bject=20has=20no=20attribute=20'dtype'=20for=E2=80=A6=20(#4814)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- deepspeed/model_implementations/transformers/ds_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py index a41df58ad059..d87d0de997b5 100644 --- a/deepspeed/model_implementations/transformers/ds_transformer.py +++ b/deepspeed/model_implementations/transformers/ds_transformer.py @@ -163,7 +163,7 @@ def forward( if (self.config.dtype in [torch.float16, torch.bfloat16, torch.int8]) \ and input.dtype == torch.float: - target_dtype = torch.half if self.dtype == torch.int8 else self.dtype + target_dtype = torch.half if self.config.dtype == torch.int8 else self.config.dtype input = input.to(target_dtype) with torch.no_grad(): From d1f1d45f4b4eb86bf5b82ed617f09f528bb00d11 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Fri, 15 Dec 2023 13:02:17 -0800 Subject: [PATCH 05/18] Update broken link in docs (#4822) resolves #4821 --- docs/_tutorials/getting-started.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_tutorials/getting-started.md b/docs/_tutorials/getting-started.md index 8d2bbf2d9964..f9a4cfdc68b4 100644 --- a/docs/_tutorials/getting-started.md +++ b/docs/_tutorials/getting-started.md @@ -8,7 +8,7 @@ tags: getting-started ## Installation * Installing is as simple as `pip install deepspeed`, [see more details](/tutorials/advanced-install/). -* To get started with DeepSpeed on AzureML, please see the [AzureML Examples GitHub](https://github.com/Azure/azureml-examples/tree/main/python-sdk/workflows/train/deepspeed) +* To get started with DeepSpeed on AzureML, please see the [AzureML Examples GitHub](https://github.com/Azure/azureml-examples/tree/main/cli/jobs/deepspeed) * DeepSpeed has direct integrations with [HuggingFace Transformers](https://github.com/huggingface/transformers) and [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). HuggingFace Transformers users can now easily accelerate their models with DeepSpeed through a simple ``--deepspeed`` flag + config file [See more details](https://huggingface.co/docs/transformers/main_classes/deepspeed). PyTorch Lightning provides easy access to DeepSpeed through the Lightning Trainer [See more details](https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html?highlight=deepspeed#deepspeed). * DeepSpeed on AMD can be used via our [ROCm images](https://hub.docker.com/r/deepspeed/rocm501/tags), e.g., `docker pull deepspeed/rocm501:ds060_pytorch110`. From b83b1c2e1c4dc4c91c4ad78773dc2232ca9f7070 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Fri, 15 Dec 2023 14:12:50 -0800 Subject: [PATCH 06/18] Update imports from Transformers (#4817) --- requirements/requirements-dev.txt | 2 +- requirements/requirements-inf.txt | 2 +- tests/unit/inference/quantization/test_intX_quantization.py | 4 ++-- tests/unit/runtime/zero/test_zero_nesting_init.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 078386c457bd..7204eead5864 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -16,5 +16,5 @@ sphinx sphinx-rtd-theme tensorboard torchvision -transformers +transformers>=4.32.1 wandb diff --git a/requirements/requirements-inf.txt b/requirements/requirements-inf.txt index 848a7f7a485d..27371e623f26 100644 --- a/requirements/requirements-inf.txt +++ b/requirements/requirements-inf.txt @@ -1,5 +1,5 @@ google lm-eval==0.3.0 protobuf -transformers +transformers>=4.32.1 transformers[sentencepiece] diff --git a/tests/unit/inference/quantization/test_intX_quantization.py b/tests/unit/inference/quantization/test_intX_quantization.py index 56df2b232d15..fd6a8e5ad2e1 100644 --- a/tests/unit/inference/quantization/test_intX_quantization.py +++ b/tests/unit/inference/quantization/test_intX_quantization.py @@ -55,7 +55,7 @@ def quantization_test_helper(pre_quant_type: torch.dtype, num_bits: int): def zero3_post_init_quantization_test_helper(cpu_offload: bool, nvme_offload: bool, bits: int): import deepspeed - from transformers.deepspeed import HfDeepSpeedConfig + from transformers.integrations.deepspeed import HfDeepSpeedConfig def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool, bits: int) -> Dict: GB = 1 << 30 @@ -172,7 +172,7 @@ def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: b def zero3_quantized_initialization_test_helper(cpu_offload: bool, nvme_offload: bool, bits: int): import deepspeed - from transformers.deepspeed import HfDeepSpeedConfig + from transformers.integrations.deepspeed import HfDeepSpeedConfig def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool, bits: int) -> Dict: GB = 1 << 30 diff --git a/tests/unit/runtime/zero/test_zero_nesting_init.py b/tests/unit/runtime/zero/test_zero_nesting_init.py index 143e7e997b13..15d82fd8be00 100644 --- a/tests/unit/runtime/zero/test_zero_nesting_init.py +++ b/tests/unit/runtime/zero/test_zero_nesting_init.py @@ -8,7 +8,7 @@ from unit.common import DistributedTest from transformers import VisionEncoderDecoderModel -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig import deepspeed From bc1b5a6c06049f39d1e5c18bbe0f29a09e11f4a3 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Fri, 15 Dec 2023 14:39:25 -0800 Subject: [PATCH 07/18] Minor updates to CI workflows (#4823) --- .github/workflows/amd-mi100.yml | 56 ------------------------------ .github/workflows/auto-sync.yml | 59 -------------------------------- .github/workflows/formatting.yml | 2 +- .github/workflows/python.yml | 2 +- 4 files changed, 2 insertions(+), 117 deletions(-) delete mode 100644 .github/workflows/amd-mi100.yml delete mode 100644 .github/workflows/auto-sync.yml diff --git a/.github/workflows/amd-mi100.yml b/.github/workflows/amd-mi100.yml deleted file mode 100644 index 7ad0f4178db4..000000000000 --- a/.github/workflows/amd-mi100.yml +++ /dev/null @@ -1,56 +0,0 @@ -name: amd-mi100 - -on: - schedule: - - cron: "0 0 * * *" - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - amd-tests: - # The type of runner that the job will run on - runs-on: [self-hosted, amd, mi100] - - # Steps represent a sequence of tasks that will be executed as part of the job - steps: - # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v3 - - - id: setup-venv - uses: ./.github/workflows/setup-venv - - - name: Install pytorch - run: | - pip install --cache-dir $TORCH_CACHE torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.1.1 - python -c "import torch; print('torch:', torch.__version__, torch)" - python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - - - name: Install transformers - run: | - git clone https://github.com/huggingface/transformers - cd transformers - # if needed switch to the last known good SHA until transformers@master is fixed - # git checkout 1cc453d33 - git rev-parse --short HEAD - pip install . - - # Runs a set of commands using the runners shell - - name: Install deepspeed - run: | - pip install .[dev,1bit,autotuning] - #python -c "from deepspeed.env_report import cli_main; cli_main()" - ds_report - - - name: Python environment - run: | - pip list - - # Runs a set of commands using the runners shell - - name: Unit tests - run: | - unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch - cd tests - pytest $PYTEST_OPTS -n 4 --verbose unit/ - pytest $PYTEST_OPTS -m 'sequential' unit/ diff --git a/.github/workflows/auto-sync.yml b/.github/workflows/auto-sync.yml deleted file mode 100644 index bfbf5a2ae37a..000000000000 --- a/.github/workflows/auto-sync.yml +++ /dev/null @@ -1,59 +0,0 @@ -name: AutoSync - -on: - push: - branches: - - 'master' - -jobs: - - Create-PR: - runs-on: ubuntu-20.04 - - steps: - - uses: actions/checkout@v3 - with: - token: ${{ secrets.GHP_TOKEN }} - repository: ${{ secrets.DST_REPO }} - ref: ${{ secrets.DST_REPO_BRANCH }} - path: dst-repo - - - name: Get PR data - run: | - echo "REPO=${{ github.repository }}" >> $GITHUB_ENV - echo "COMMIT_SHA=${{ github.event.after }}" >> $GITHUB_ENV - echo "SHORT_SHA=$(echo ${{ github.event.after }} | cut -c1-8)" >> $GITHUB_ENV - echo "USERNAME=${{ github.event.head_commit.author.username }}" >> $GITHUB_ENV - echo "USER_EMAIL=${{ github.event.head_commit.author.username }}@users.noreply.github.com" >> $GITHUB_ENV - echo "PR_NAME=$(echo '${{ github.event.head_commit.message }}' | head -1 | sed 's|#|${{ github.repository }}#|g')" >> $GITHUB_ENV - - - name: Cherry pick commit - continue-on-error: true - run: | - cd dst-repo - git config --global user.name ${{ env.USERNAME }} - git config --global user.email ${{ env.USER_EMAIL }} - git fetch https://github.com/${{ env.REPO }}.git master - git cherry-pick FETCH_HEAD --strategy-option octopus - - - name: Add modified files - run: | - cd dst-repo - git add . - - - name: Create Pull Request - uses: peter-evans/create-pull-request@v4 - with: - path: dst-repo - token: ${{ secrets.GHP_TOKEN }} - body: | - **Auto-generated PR** - Repo - [${{ env.REPO }}](https://github.com/${{ env.REPO }}) - PR name - ${{ env.PR_NAME }} - Commit - ${{ env.REPO }}@${{ env.COMMIT_SHA }} - Author - @${{ env.USERNAME }} - branch: AutoPR/${{ env.SHORT_SHA }} - assignees: ${{ env.USERNAME }} - title: ${{ env.PR_NAME }} - labels: AutoPR - author: ${{ env.USERNAME }} <${{ env.USER_EMAIL }}> diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index a168af277fb8..26f3819dd2bf 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -16,7 +16,7 @@ concurrency: jobs: # formatting and basic install on cpu-only machine - formatting: + unit-tests: runs-on: ubuntu-20.04 steps: diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 279bad471c01..59770a5e23b3 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -17,7 +17,7 @@ concurrency: cancel-in-progress: true jobs: - version-check: + unit-tests: strategy: matrix: pyVersion: ["3.6", "3.7", "3.8", "3.9", "3.10"] From faa00b1373e2e5628c660d0f40cab485cd960c33 Mon Sep 17 00:00:00 2001 From: baodi Date: Sat, 16 Dec 2023 07:16:33 +0800 Subject: [PATCH 08/18] fix falcon model load from_config meta_data error (#4783) --- deepspeed/module_inject/auto_tp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 6f545d4cb13b..b9a726bec67f 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -121,7 +121,9 @@ class Loading(): def is_load_module(module): load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm] - load_layer_names = ["LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm"] + load_layer_names = [ + "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear" + ] return module.__class__ in load_layers or module._get_name() in load_layer_names def load_buffer(module, state_dict, prefix): From 449e454f83bb6a14b0de359660d4b206d5c3feed Mon Sep 17 00:00:00 2001 From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com> Date: Sat, 16 Dec 2023 01:17:38 +0200 Subject: [PATCH 09/18] mv DeepSpeedEngine param_names dict init post _configure_distributed_model (#4803) In some backends, when params are being moved from host to device, they might changed their python object id(), which uses a the key in the param_names dictionary. in such case this dict might become invalid. Co-authored-by: Michael Wyatt --- deepspeed/runtime/engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 44b44c79ba55..c5f4d3e6530d 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -232,9 +232,6 @@ def __init__( # for debug purposes - can then debug print: debug_get_module_name(module) debug_extract_module_and_param_names(model) - # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict - self.param_names = {param: name for name, param in model.named_parameters()} - self._do_args_sanity_check(args) self._configure_with_arguments(args, mpu) self._do_sanity_check() @@ -261,6 +258,9 @@ def __init__( # Configure distributed model self._configure_distributed_model(model) + # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict + self.param_names = {param: name for name, param in model.named_parameters()} + self._get_model_parameters() see_memory_usage(f"DeepSpeed Engine: After configure distributed model") From d37fc25d568bbcfb36772d9eae4539d8deac1bd2 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Fri, 15 Dec 2023 15:31:46 -0800 Subject: [PATCH 10/18] Refactor launcher user arg parsing (#4824) Splitting work from #4769 because we are still debugging transformers integration issues. Parsing was broken for user arguments (see #4795). Additionally, parsing of user arguments is tricky and there are lots of edge cases. For example: #4660, #4716, #3967. I've attempted to accommodate all of the possible types of string inputs and added unit tests. --- deepspeed/launcher/multinode_runner.py | 14 ++++-- deepspeed/launcher/runner.py | 4 -- tests/unit/launcher/test_user_args.py | 64 ++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 7 deletions(-) create mode 100644 tests/unit/launcher/test_user_args.py diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py index 730146f5bcd2..07c70f3cc6c5 100644 --- a/deepspeed/launcher/multinode_runner.py +++ b/deepspeed/launcher/multinode_runner.py @@ -56,13 +56,21 @@ def __init__(self, args, world_info_base64): def backend_exists(self): return shutil.which('pdsh') + def parse_user_args(self): + processed_args = [] + for arg in self.args.user_args: + # With pdsh, if we are passing a string as an argument, it will get + # split on whitespace. To avoid this and support strings that + # contain '"', we do this extra processing step: + if " " in arg: + arg = '"{}"'.format(arg.replace('"', '\\"')) + processed_args.append(arg) + return processed_args + @property def name(self): return "pdsh" - def parse_user_args(self): - return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args)) - def get_cmd(self, environment, active_resources): environment['PDSH_RCMD_TYPE'] = 'ssh' if self.args.ssh_port is not None: # only specify ssh port if it is specified diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index a7fa2b5053e5..99ebc9771e41 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -12,7 +12,6 @@ import os import re import sys -import shlex import json import base64 import argparse @@ -389,9 +388,6 @@ def parse_num_nodes(str_num_nodes: str, elastic_training: bool): def main(args=None): args = parse_args(args) - # For when argparse interprets remaining args as a single string - args.user_args = shlex.split(" ".join(list(map(lambda x: x if x.startswith("-") else f'"{x}"', args.user_args)))) - if args.elastic_training: assert args.master_addr != "", "Master Addr is required when elastic training is enabled" diff --git a/tests/unit/launcher/test_user_args.py b/tests/unit/launcher/test_user_args.py new file mode 100644 index 000000000000..99afd0f2cfa7 --- /dev/null +++ b/tests/unit/launcher/test_user_args.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import subprocess + +from deepspeed.accelerator import get_accelerator + +if not get_accelerator().is_available(): + pytest.skip("only supported in accelerator environments.", allow_module_level=True) + +user_arg_test_script = """import argparse +parser = argparse.ArgumentParser() +parser.add_argument("--prompt", type=str) +parser.add_argument("--local_rank", type=int, default=0) +parser.add_argument("--world_size", type=int, default=1) +args = parser.parse_args() +print("ARG PARSE SUCCESS") +""" + + +@pytest.fixture(scope="function") +def user_script_fp(tmpdir): + script_fp = tmpdir.join("user_arg_test.py") + with open(script_fp, "w") as f: + f.write(user_arg_test_script) + return script_fp + + +@pytest.fixture(scope="function") +def cmd(user_script_fp, prompt, multi_node): + if multi_node: + cmd = ("deepspeed", "--force_multi", "--num_nodes", "1", "--num_gpus", "1", user_script_fp, "--prompt", prompt) + else: + cmd = ("deepspeed", "--num_nodes", "1", "--num_gpus", "1", user_script_fp, "--prompt", prompt) + return cmd + + +@pytest.mark.parametrize("prompt", [ + '''"I am 6' tall"''', """'I am 72" tall'""", """'"translate English to Romanian: "'""", + '''I'm going to tell them "DeepSpeed is the best"''' +]) +@pytest.mark.parametrize("multi_node", [True, False]) +def test_user_args(cmd): + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = p.communicate() + assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}" + + +def test_bash_string_args(tmpdir, user_script_fp): + bash_script = f""" + ARGS="--prompt 'DeepSpeed is the best'" + echo ${{ARGS}}|xargs deepspeed --num_nodes 1 --num_gpus 1 {user_script_fp} + """ + + bash_fp = tmpdir.join("bash_script.sh") + with open(bash_fp, "w") as f: + f.write(bash_script) + + p = subprocess.Popen(["bash", bash_fp], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = p.communicate() + assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}" From 65b7727758a4c0ee08597a88ab4f051abcfc2a8a Mon Sep 17 00:00:00 2001 From: Alienfeel Date: Sat, 16 Dec 2023 07:35:00 +0800 Subject: [PATCH 11/18] Fix 4649 (#4650) Co-authored-by: Michael Wyatt --- deepspeed/launcher/runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index 99ebc9771e41..4f45e1831b48 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -443,7 +443,11 @@ def main(args=None): if not args.master_addr: assert multi_node_exec first_host = list(active_resources.keys())[0] - hostname_cmd = [f"ssh {first_host} hostname -I"] + ssh_check_cmd = "ssh " + if args.ssh_port is not None: + ssh_check_cmd += f" -p {args.ssh_port}" + ssh_check_cmd += f" {first_host} hostname -I" + hostname_cmd = [ssh_check_cmd] try: result = subprocess.check_output(hostname_cmd, shell=True) except subprocess.CalledProcessError as err: From 4d866bd55a6b2b924987603b599c1f8f35911c4b Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Fri, 15 Dec 2023 17:02:38 -0800 Subject: [PATCH 12/18] Update version.txt after 0.12.5 release (#4826) **Auto-generated PR to update version.txt after a DeepSpeed release** Released version - 0.12.5 Author - @mrwyattii Co-authored-by: mrwyattii --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 43c2417ca0c6..dabff2f13810 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.12.5 +0.12.6 From 4559dadd367453befd6c8f7d0049f8e900c897c3 Mon Sep 17 00:00:00 2001 From: BacharL Date: Mon, 18 Dec 2023 20:10:55 +0200 Subject: [PATCH 13/18] Cache metadata for TP activations and grads (#4360) PartitionedTensor.from_meta will cause device to host synchronization when reading the meta tensor in meta = meta.tolist() Added cpu cache for the meta tensor to avoid this synchronization in every activation and grad communication between the ranks. The meta tensor is assumed to be static since activation shape must be static. The user must call reset_activation_shape if any of the dimentions change. Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/pipe/engine.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index c8d6a0bff444..27fa5b69d35d 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -182,6 +182,10 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): self.first_output_send = True self.first_gradient_send = True + self.pipe_partition_input_meta_cache = None + self.pipe_partition_output_meta_cache = None + self.pipe_partition_grad_meta_cache = None + self.grad_partition_grad_layer_meta_cache = None #stores the loss for the current micro batch being processed self.loss = torch.tensor(0.0).to(self.device) @@ -309,6 +313,11 @@ def reset_activation_shape(self): self.grad_layer = None self.meta_buffer = None + self.pipe_partition_input_meta_cache = None + self.pipe_partition_output_meta_cache = None + self.pipe_partition_grad_meta_cache = None + self.grad_partition_grad_layer_meta_cache = None + def train_batch(self, data_iter=None): """Progress the pipeline to train the next batch of data. The engine will ingest ``self.train_batch_size()`` total samples collectively across all workers. @@ -641,7 +650,9 @@ def _exec_forward_pass(self, buffer_id): # collect the partitioned input from the previous stage if self.is_pipe_partitioned and not self.is_first_stage(): - part_input = PartitionedTensor.from_meta(meta=inputs[0], + if self.pipe_partition_input_meta_cache is None: + self.pipe_partition_input_meta_cache = inputs[0].to('cpu') + part_input = PartitionedTensor.from_meta(meta=self.pipe_partition_input_meta_cache, local_part=inputs[1], group=self.grid.get_slice_parallel_group()) @@ -732,7 +743,9 @@ def _exec_backward_pass(self, buffer_id): # careful to also restore the computational graph of the tensors we partitioned. if self.is_pipe_partitioned: if self.is_grad_partitioned: - part_output = PartitionedTensor.from_meta(meta=outputs[0], + if self.pipe_partition_output_meta_cache is None: + self.pipe_partition_output_meta_cache = outputs[0].to('cpu') + part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_output_meta_cache, local_part=outputs[1], group=self.grid.get_slice_parallel_group()) self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full() @@ -745,7 +758,9 @@ def _exec_backward_pass(self, buffer_id): grad_tensors = self.grad_layer if self.is_grad_partitioned: #print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}') - part_grad = PartitionedTensor.from_meta(meta=self.grad_layer[0], + if self.grad_partition_grad_layer_meta_cache is None: + self.grad_partition_grad_layer_meta_cache = self.grad_layer[0].to('cpu') + part_grad = PartitionedTensor.from_meta(meta=self.grad_partition_grad_layer_meta_cache, local_part=self.grad_layer[1], group=self.grid.get_slice_parallel_group()) grad_tensors = (part_grad.full(), *grad_tensors[2:]) @@ -1088,7 +1103,9 @@ def _exec_recv_grads(self, buffer_id): # XXX these shapes are hardcoded for Megatron # Restore partitioned output if it was partitioned and we are sending full gradients if self.is_pipe_partitioned and not self.is_grad_partitioned: - part_output = PartitionedTensor.from_meta(meta=outputs[0], + if self.pipe_partition_grad_meta_cache is None: + self.pipe_partition_grad_meta_cache = outputs[0].to('cpu') + part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_grad_meta_cache, local_part=outputs[1], group=self.grid.get_slice_parallel_group()) outputs[0].data = part_output.full() From 4c2cac03402b64901dc4569d85190c54e0c89a28 Mon Sep 17 00:00:00 2001 From: Omar Elayan <142979319+oelayan7@users.noreply.github.com> Date: Mon, 18 Dec 2023 20:29:04 +0200 Subject: [PATCH 14/18] Inference changes for incorporating meta loading checkpoint (#4692) 1. In both files, the same logic was done that if when it is meta no need to move the tensors to the device. 2. Deletion of an unused member of the class --------- Co-authored-by: Olatunji Ruwase --- deepspeed/inference/engine.py | 12 +++++------ deepspeed/module_inject/auto_tp.py | 34 ++++++++++++++++++++---------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 149d20cd9305..71330b982ee0 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -109,11 +109,6 @@ def __init__(self, model, config): assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \ "If you want to use cuda graph, please upgrade torch to at least v1.10" - # Check if model passed to engine is loaded w/ meta tensors, in which case - # kernel injection must be enabled. - # NOTE: This check assumes a Hugging Face hierarchy for the device type i.e. module.device.type - self.model_meta_device = self.module.device.type == 'meta' if hasattr(self.module, "device") else False - # convert model to intended dtype if config.dtype: self._convert_to_dtype(config) @@ -170,7 +165,12 @@ def __init__(self, model, config): self._apply_injection_policy(config, client_module) device = get_accelerator().current_device_name() - self.module.to(device) + # NOTE: This check assumes a Hugging Face hierarchy for the device type i.e. module.device.type + is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta' + if is_meta_device: + self.module.to_empty(device=device) + else: + self.module.to(device) if config.tensor_parallel.tp_size > 1: _rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name()) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index b9a726bec67f..af0566cbb3cb 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -17,6 +17,16 @@ from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list +def move(tensor, device): + if tensor.is_meta: + return torch.empty_like(tensor, device=device) + else: + # Using new tensors help in freeing memory (after split for example) was done before by calling clone(). + # Using copy=True instead of clone() will help in case of cpu --> cpu. + # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced. + return tensor.to(device, copy=True) + + class ReplaceWithTensorSlicing: def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0): @@ -318,7 +328,7 @@ def _replace(self, child, name, conv_linear_layer): data = child.weight.data.split(get_shard_size_list( weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size), dim=1) - data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach() + data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach() del data setattr(child, "replaced", True) @@ -326,9 +336,10 @@ def _replace(self, child, name, conv_linear_layer): return LmHeadLinearAllreduce( torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(), child.bias if child.bias is None else torch.nn.parameter.Parameter( - child.bias.to(get_accelerator().current_device_name())), self.mp_group) + move(child.bias, + get_accelerator().current_device_name())), self.mp_group) return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \ - torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name())), self.mp_group) + torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group) else: # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] @@ -340,30 +351,31 @@ def _replace(self, child, name, conv_linear_layer): #for detecting fused type module_str = str(self.module).strip() #The copy is a regular copy, The shape of dst and src is the same - data_dc = prepare_tp_fused_qkvw(module_str, child.weight.data, self.mp_size, mp_replace.gpu_index) + data_dc = move( + prepare_tp_fused_qkvw(module_str, child.weight.data, self.mp_size, mp_replace.gpu_index), + get_accelerator().current_device_name()) - bias_data_dc = None if child.bias is None else prepare_tp_fused_qkvw( - module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to( - get_accelerator().current_device_name()) + bias_data_dc = None if child.bias is None else move( + prepare_tp_fused_qkvw(module_str, child.bias.data, self.mp_size, mp_replace.gpu_index), + get_accelerator().current_device_name()) else: data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size), dim=1 if self.conv_linear_layer else 0) - data_dc = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()).clone().detach() + data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach() del data if child.bias is not None: bias_data = child.bias.data.split(get_shard_size_list( weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size), dim=0) - bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) + bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name()) bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False) del bias_data else: bias_data_dc = None setattr(child, "replaced", True) - return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc.to(get_accelerator().current_device_name()), requires_grad=False), \ - bias=bias_data_dc) + return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc) def _slice_embedding(self, child, name, conv_linear_layer): if getattr(child, "replaced", False) == True: From 83fa673aaec23879248861a61028b775498a67a5 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Mon, 18 Dec 2023 10:32:19 -0800 Subject: [PATCH 15/18] Update CODEOWNERS (#4838) --- CODEOWNERS | 62 +++++++++++++++++++++++++++--------------------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/CODEOWNERS b/CODEOWNERS index 2410b3ebc09b..2c16aef39a1b 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -7,50 +7,50 @@ # top-level repo folders -/.github/ @jeffra @mrwyattii @loadams -/azure/ @jeffra @awan-10 -/benchmarks/ @jeffra @awan-10 @mrwyattii @molly-smith -/bin/ @jeffra -/csrc/ @RezaYazdaniAminabadi @awan-10 @jeffra @cmikeh2 @arashb -/deepspeed/ @jeffra -/docker/ @jeffra @awan-10 -/docs/ @jeffra @mrwyattii -/examples/ @jeffra @awan-10 @mrwyattii -/op_builder/ @jeffra @RezaYazdaniAminabadi @cmikeh2 -/release/ @jeffra @mrwyattii -/requirements/ @jeffra @mrwyattii -/scripts/ @jeffra @awan-10 -/tests/ @jeffra @mrwyattii @tjruwase +/.github/ @mrwyattii @loadams +/azure/ @mrwyattii @awan-10 +/benchmarks/ @awan-10 @mrwyattii +/bin/ @mrwyattii +/csrc/ @awan-10 @mrwyattii @cmikeh2 @arashb +/deepspeed/ @mrwyattii +/docker/ @mrwyattii @awan-10 +/docs/ @mrwyattii +/examples/ @awan-10 @mrwyattii +/op_builder/ @mrwyattii @cmikeh2 +/release/ @loadams @mrwyattii +/requirements/ @loadams @mrwyattii +/scripts/ @mrwyattii @awan-10 +/tests/ @mrwyattii @tjruwase @loadams # deepspeed -/deepspeed/autotuning/ @cli99 +/deepspeed/autotuning/ @mrwyattii /deepspeed/checkpoint/ @tjruwase /deepspeed/comm/ @awan-10 -/deepspeed/compression/ @yaozhewei @minjiaz @xiaoxiawu-microsoft @conglongli -/deepspeed/elasticity/ @jeffra @awan-10 -/deepspeed/launcher/ @jeffra @awan-10 -/deepspeed/module_inject/ @RezaYazdaniAminabadi @jeffra @mrwyattii @awan-10 @cmikeh2 @arashb +/deepspeed/compression/ @minjiaz @xiaoxiawu-microsoft @conglongli +/deepspeed/elasticity/ @mrwyattii @awan-10 +/deepspeed/launcher/ @mrwyattii @awan-10 +/deepspeed/module_inject/ @mrwyattii @awan-10 @cmikeh2 @arashb /deepspeed/moe/ @awan-10 -/deepspeed/monitor/ @awan-10 @jeffra -/deepspeed/nebula/ @tjruwase @jeffra -/deepspeed/ops/ @RezaYazdaniAminabadi @jeffra @mrwyattii @awan-10 @cmikeh2 @arashb +/deepspeed/monitor/ @awan-10 @mrwyattii +/deepspeed/nebula/ @tjruwase @mrwyattii +/deepspeed/ops/ @mrwyattii @awan-10 @cmikeh2 @arashb /deepspeed/pipe/ @ShadenSmith @duli2012 -/deepspeed/profiling/ @cli99 -/deepspeed/utils/ @jeffra @tjruwase @awan-10 +/deepspeed/profiling/ @ShijieZZZZ +/deepspeed/utils/ @mrwyattii @tjruwase @awan-10 # inference -/deepspeed/inference/ @RezaYazdaniAminabadi @jeffra @mrwyattii @awan-10 @cmikeh2 @arashb -/deepspeed/model_implementations/ @RezaYazdaniAminabadi @jeffra @mrwyattii @awan-10 @cmikeh2 @arashb +/deepspeed/inference/ @mrwyattii @awan-10 @cmikeh2 @arashb +/deepspeed/model_implementations/ @mrwyattii @awan-10 @cmikeh2 @arashb # training -/deepspeed/runtime/ @jeffra @tjruwase -/deepspeed/runtime/activation_checkpointing/ @jeffra @tjruwase -/deepspeed/runtime/checkpoint_engine/ @tjruwase @jeffra +/deepspeed/runtime/ @mrwyattii @tjruwase +/deepspeed/runtime/activation_checkpointing/ @mrwyattii @tjruwase +/deepspeed/runtime/checkpoint_engine/ @tjruwase @mrwyattii /deepspeed/runtime/comm/ @awan-10 /deepspeed/runtime/compression/ @awan-10 @conglongli /deepspeed/runtime/data_pipeline/ @conglongli -/deepspeed/runtime/fp16/ @jeffra @tjruwase +/deepspeed/runtime/fp16/ @mrwyattii @tjruwase /deepspeed/runtime/fp16/onebit/ @conglongli @awan-10 /deepspeed/runtime/pipe/ @ShadenSmith @duli2012 /deepspeed/runtime/swap_tensor/ @tjruwase @mrwyattii -/deepspeed/runtime/zero/ @jeffra @tjruwase @samyam @mrwyattii +/deepspeed/runtime/zero/ @tjruwase @mrwyattii From c20f6fa4e0a7772400fc6b52999a2a9c3263cec0 Mon Sep 17 00:00:00 2001 From: baodi Date: Tue, 19 Dec 2023 02:45:01 +0800 Subject: [PATCH 16/18] support baichuan model: (#4721) * fix Baichuan meta data error * add BaichuanLayer and DecoderLayer to glmtype when prepare tp fused qkvw * add get_alibi_mask function for Baichuan to enable TP --------- Co-authored-by: Lai, Yejing Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase Co-authored-by: Michael Wyatt --- deepspeed/inference/engine.py | 6 +++++- deepspeed/module_inject/auto_tp_model_utils.py | 10 ++++++++++ deepspeed/module_inject/fusedqkv_utils.py | 4 +++- deepspeed/module_inject/replace_module.py | 7 +++++-- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 71330b982ee0..5fb5540dafa7 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -26,7 +26,7 @@ from ..module_inject.auto_tp import AutoTP from ..module_inject.replace_policy import generic_policies -from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor +from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor, get_alibi_mask from ..ops.transformer.inference.ds_attention import DeepSpeedSelfAttention from ..model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference @@ -220,6 +220,10 @@ def build_alibi_tensor(self): if hasattr(self.module.transformer, 'build_mpt_alibi_tensor'): self.module.transformer.build_mpt_alibi_tensor_orig = self.module.transformer.build_mpt_alibi_tensor self.module.transformer.__class__.build_mpt_alibi_tensor = build_mpt_alibi_tensor + if hasattr(self.module, 'model'): + if hasattr(self.module.model, 'get_alibi_mask'): + self.module.model.get_alibi_mask_orig = self.module.model.get_alibi_mask + self.module.model.__class__.get_alibi_mask = get_alibi_mask def build_attn_bias(self): if hasattr(self.module, 'transformer'): diff --git a/deepspeed/module_inject/auto_tp_model_utils.py b/deepspeed/module_inject/auto_tp_model_utils.py index 51e52e3258dd..a71b1a54d6f6 100644 --- a/deepspeed/module_inject/auto_tp_model_utils.py +++ b/deepspeed/module_inject/auto_tp_model_utils.py @@ -61,6 +61,16 @@ def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) +def get_alibi_mask(self, tensor, seq_length_with_past): + mask = self.get_alibi_mask_orig(tensor, seq_length_with_past) + if not self.training and dist.is_initialized(): + num_heads_per_rank = get_shard_size(self.n_head, dist.get_world_size()) + offset = sum(get_shard_size_list(self.n_head, dist.get_world_size())[0:dist.get_rank()]) + mask = mask[offset:num_heads_per_rank + offset, :seq_length_with_past, :seq_length_with_past] + + return mask + + def build_mpt_atten_bias_tensor(self, device, dtype, diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py index 8616b3505488..d61e78ab8d0e 100644 --- a/deepspeed/module_inject/fusedqkv_utils.py +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -17,7 +17,7 @@ def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0): def require_tp_fused_qkvw(name, mp_size): - fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv'] + fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack'] if mp_size == 1: return False @@ -36,6 +36,8 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index): 'GLMBlock': 'glmtype', "MPTBlock": 'glmtype', "MptBlock": 'glmtype', + "BaichuanLayer": 'glmtype', + "DecoderLayer": 'glmtype', } def _codegen_type_transpose(input, mp_size, codegen_mp_num=4): diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index e3cc64c4f37e..5b7d2209d89e 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -314,10 +314,13 @@ def set_lm_head(module): module.lm_head, "weight") and module.lm_head.weight.is_meta: module.lm_head.weight = embedding_weight # enable tensor parallel for the last linear - if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and not module.lm_head.weight.is_meta: + if hasattr(module, "lm_head") and hasattr(module.lm_head, + "weight") and not module.lm_head.weight.is_meta and isinstance( + module.lm_head, torch.nn.Linear): module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head") elif hasattr(module, "embed_out") and hasattr(module.embed_out, - "weight") and not module.embed_out.weight.is_meta: + "weight") and not module.embed_out.weight.is_meta and isinstance( + module.embed_out, torch.nn.Linear): module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out") return module From 3f1c3c2ad3fe89651327aec8703b696331bdef47 Mon Sep 17 00:00:00 2001 From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com> Date: Mon, 18 Dec 2023 22:57:30 +0200 Subject: [PATCH 17/18] inference engine: check if accelerator supports FP16 (#4832) Co-authored-by: Michael Wyatt --- deepspeed/inference/engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 5fb5540dafa7..ccecc8376ad6 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -71,6 +71,9 @@ def __init__(self, model, config): if hasattr(self.module, "config"): TransformerPolicy.hf_model_config = self.module.config + if config.dtype == torch.half and not get_accelerator().is_fp16_supported(): + raise ValueError("Type fp16 is not supported.") + # todo: keep this self.injection_dict because we don't use to change config.injection_policy API # todo: this will get changed when Molly's PR on auto injection dict is merged self.injection_dict = config.injection_policy From a00bdde86a99e39f12699ecb25657724973a9cda Mon Sep 17 00:00:00 2001 From: Gavin Goodship Date: Mon, 18 Dec 2023 21:17:50 +0000 Subject: [PATCH 18/18] Update zeropp.md (#4835) Doc corrections --------- Co-authored-by: Michael Wyatt Co-authored-by: Michael Wyatt --- docs/_tutorials/zeropp.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/_tutorials/zeropp.md b/docs/_tutorials/zeropp.md index 866bb9389e22..32080d1a16ad 100644 --- a/docs/_tutorials/zeropp.md +++ b/docs/_tutorials/zeropp.md @@ -10,8 +10,8 @@ We recommend that you read the tutorials on [Getting Started](/getting-started/) ## Three Components of ZeRO++ ZeRO++ consists of three key designs, namely quantized weights (*qwZ*), hiearchical partitioning ZeRO (*hpZ*), and quantized gradients (*qgZ*): - - *qwZ* applies block-based quantization to reduce ZeRO parameter all-gather communication volume by half from FP16 to INT8) - - *hpZ* eliminates inter-node backward parameter all-gather communication through data remapping and recomputation + - *qwZ* applies block-based quantization to reduce ZeRO parameter all-gather communication volume by half from FP16 to INT8. + - *hpZ* eliminates inter-node backward parameter all-gather communication through data remapping and recomputation. - *qgZ* replaces gradients allreduce collective with a new communication efficient all-to-all based quantized gradient averaging. Collectively, the three optimization reduces communication volume by 4x compared to ZeRO baseline. Each of the three components can be enabled independent of each other and collectively as a group as described in the next section. @@ -24,9 +24,9 @@ For this tutorial, we will configure a 18 billion parameter GPT-2 model using th ## Training a 18B parameter GPT-2 with ZeRO++ There are no change needed to the user code. However, since ZeRO++ extends ZeRO Stage 3 (ZeRO-3), appropriate flags need to be added to activate each or all of the three ZeRO++ communication collective optimizations. The three flags and their meanings and defaults and preferred values: - - zero_quantized_weights: Boolean indicating whether to use quantized zero weights (*qwZ*), default is false - - zero_hpz_partition_size: number of ranks in *hpZ* (secondary partition) group, default is 1 meaning no hpZ, ideal is number of ranks (gpus) per node - - zero_quantized_gradients: Boolean indicating whether to use quantized zero gradients (*qgZ*), default is false + - zero_quantized_weights: Boolean indicating whether to use quantized zero weights (*qwZ*), default is false. + - zero_hpz_partition_size: number of ranks in *hpZ* (secondary partition) group, default is 1 meaning no hpZ, ideal is number of ranks (gpus) per node. + - zero_quantized_gradients: Boolean indicating whether to use quantized zero gradients (*qgZ*), default is false. ### DeepSpeed Configuration Changes