
Commit

Merge branch 'master' into lekurile/update_ds_chat_ci
lekurile authored Nov 7, 2023
2 parents 5fe3f39 + cbec96b commit 4f6be9e
Showing 16 changed files with 438 additions and 19 deletions.
18 changes: 14 additions & 4 deletions README.md
@@ -15,13 +15,23 @@
## Latest News
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>

* [2023/11] [DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-offloadpp)
* [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen)
* [2023/10] [DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md)]
* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[White paper](https://arxiv.org/abs/2310.04610)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)]
* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md)
* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md)
* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses) [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/japanese/README.md)]
* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/) [[English](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md)]
* [2023/08] [DeepSpeed ZeRO-Inference: 20x faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md)

<!-- NOTE: we must use html for news items otherwise links will be broken in the 'more news' section -->
<details>
<summary>More news</summary>
<ul>
<li>[2023/08] <a href="https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md">DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements</a></li>

<li>[2023/08] <a href="https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses">DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models</a> [<a href="https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/chinese/README.md">中文</a>] [<a href="https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/japanese/README.md">日本語</a>]</li>

<li>[2023/06] <a href="https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/">ZeRO++: A leap in speed for LLM and chat model training with 4X less communication</a> [<a href="https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/">English</a>] [<a href="https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md">中文</a>] [<a href="https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md">日本語</a>]</li>
</ul>
</details>

---

52 changes: 52 additions & 0 deletions blogs/deepspeed-offloadpp/README.md
@@ -0,0 +1,52 @@
# DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow

Deep learning has been successfully adopted in a wide range of applications such as speech recognition, chatbots, and text and image generation. To achieve better serving accuracy, model sizes have grown dramatically: taking language models as an example, from BERT with 110 million parameters to Megatron-Turing NLG with 530 billion parameters, model size has grown almost 5000x. Given limited GPU memory, we need to utilize it efficiently to achieve good system throughput.

ZeRO offers a memory-efficient data-parallel training scheme. Even with ZeRO, however, GPU memory is often insufficient to hold all the parameters of large models such as LLMs. ZeRO-Offload was introduced to solve this problem: it relieves GPU memory pressure by offloading data and compute to the CPU while minimizing CPU-GPU data-copy overhead. Because CPU memory is often orders of magnitude larger than GPU memory, ZeRO-Offload was the first work to enable billion-parameter training even with very limited GPU memory (in the extreme, a single GPU). ZeRO-Offload provides excellent performance when the model is several times larger than the total GPU memory.
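
To ground the discussion, the sketch below shows the kind of DeepSpeed configuration that enables ZeRO-Offload (ZeRO Stage 3 with optimizer states offloaded to the CPU); the batch size and precision settings are illustrative placeholders rather than recommendations.

```python
# Minimal sketch of a DeepSpeed config with ZeRO-Offload enabled:
# ZeRO Stage 3 partitioning plus optimizer-state offloading to CPU.
# The numeric values are placeholders, not tuning advice.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",        # keep optimizer states (and CPUAdam) on the host
            "pin_memory": True,     # pinned host buffers to speed up CPU<->GPU copies
        },
    },
}
```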

However, system efficiency is still far from optimal when ZeRO-Offload is adopted in some scenarios. In particular, for small-batch training, or for models that do not fit in GPU memory yet are not orders of magnitude larger than GPU memory capacity, CPU offloading not only introduces long end-to-end latency but also underutilizes GPU compute resources. To reduce both the memory-copy latency and the inefficient GPU utilization in these offload cases, we propose ZeRO-Offload++, which leverages the CPU and GPU coherently. ZeRO-Offload++ includes three new features: _Twin-Flow_, MemCpy reduction, and CPUAdam optimization. Today we are releasing our __Twin-Flow__ feature.

The key benefits are:
* With _Twin-Flow_, ZeRO-Offload++ achieves up to **6x** training speedup compared with ZeRO-Offload.
* A high-level knob in the DeepSpeed config JSON makes it easy to enable and tune (see the configuration sketch in the _Twin-Flow_ section below).

![h100-img](./images/h100-8.png)

## Twin-Flow

In DeepSpeed, when training with a popular optimizer such as Adam, optimizer offloading follows an all-or-nothing policy. In the simplified example shown in the figure below, without offloading all parameters are updated on the GPU by the FusedAdam optimizer; if offloading is enabled, all model weights are instead updated with CPUAdam.

![cpu-offload-img](./images/cpu-offload.png)
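
To make the all-or-nothing policy concrete, here is an illustrative sketch (not the DeepSpeed source) of how the two modes differ; the learning rate is a placeholder.

```python
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

def build_optimizer(params, offload_optimizer_to_cpu: bool):
    # All-or-nothing policy: either every parameter is stepped by the GPU
    # FusedAdam, or every parameter is stepped by DeepSpeedCPUAdam with its
    # optimizer states resident in CPU memory.
    if offload_optimizer_to_cpu:
        return DeepSpeedCPUAdam(params, lr=1e-4)  # states live on the host
    return FusedAdam(params, lr=1e-4)             # states live on the device
```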

The major downside of this all-or-nothing offloading is that when all optimizer states are offloaded to the CPU, both GPU memory and GPU compute remain underutilized. Although increasing the batch size improves the GPU utilization rate, each training iteration is still very long compared with the no-offloading case. To improve GPU compute and memory utilization and to shorten the training iteration, we introduce a new feature in our DeepSpeed training engine called _Twin-Flow_.

In contrast, _Twin-Flow_ allows one portion of the optimizer states to be held in CPU memory while the other portion remains in GPU memory. When the optimization step is triggered, the CPU and GPU perform parameter updates simultaneously. Once offloading is enabled, we provide an offload-ratio configuration that lets users choose what percentage of model weights are updated on the CPU, with the rest updated on the GPU. "_Twin_" comes from the idea that the CPU and GPU run the same optimizer function; "_Flow_" means that parameters are not only held in both host and device memory but also computed with both CPU and GPU cores.

As shown in the figure below, with ZeRO-Offload enabled and a _Twin-Flow_ ratio of 0.4 (40%), the DeepSpeed training engine automatically runs the optimizer step for the first 40% of the weights (i.e., 0-40%) on the CPU with CPUAdam, and uses FusedAdam on the GPU to update the remaining 60% (i.e., 40-100%). With _Twin-Flow_, we can therefore achieve decent GPU memory and core utilization while reducing training iteration time in optimizer-offloading cases.

![_Twin-Flow_-img](./images/twin-offload.png)
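
Concretely, the ratio is exposed through the `offload_optimizer` section of the DeepSpeed config. Below is a minimal sketch of the setting described above (ratio 0.4); the other values are illustrative placeholders, and the partial-offload ratio only applies with ZeRO Stage 3.

```python
# Sketch of a DeepSpeed config enabling Twin-Flow: roughly 40% of the
# optimizer work stays on the CPU (CPUAdam) and the remaining 60% runs
# on the GPU (FusedAdam). Requires ZeRO Stage 3.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,   # placeholder
    "bf16": {"enabled": True},             # placeholder precision choice
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "ratio": 0.4,                  # Twin-Flow offload ratio
        },
    },
}
```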

Note that this _Twin-Flow_ ratio can be adjusted based on how much idle GPU memory is available. The smaller the ratio, the more GPU memory and cores are used and the shorter the training iteration becomes. The ideal setting pushes GPU memory usage as close to its upper bound as possible in order to minimize training iteration time.
Note also that _Twin-Flow_ is not limited to the Adam optimizer; it can be applied to any optimizer (e.g., AdaGrad) from the user side.
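
Under the hood, the training engine splits the flattened parameter sub-groups between devices according to this ratio (see the `stage3.py` changes later in this commit); the standalone sketch below mirrors that assignment logic for illustration.

```python
def assign_subgroup_devices(num_subgroups: int, offload_ratio: float, gpu: str = "cuda") -> dict:
    """Twin-Flow split: the first int(offload_ratio * num_subgroups) parameter
    sub-groups are stepped on the CPU (CPUAdam); the rest stay on the GPU (FusedAdam)."""
    cpu_count = int(offload_ratio * num_subgroups)
    return {i: ("cpu" if i < cpu_count else gpu) for i in range(num_subgroups)}

# Example: 10 sub-groups with ratio 0.4 -> sub-groups 0-3 on CPU, 4-9 on GPU.
print(assign_subgroup_devices(10, 0.4))
```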

## Performance Evaluation

We conducted our performance evaluations on both A100 and H100 DGX machines, testing OPT models with 13B and 30B parameters. We ran OPT-13B training on an 8x A100 DGX machine and OPT-30B training on an 8x H100 DGX machine. With some tuning of the offload ratio in ZeRO-Offload++, we achieve 6x and 3x training speedups for Meta's OPT models on a single DGX-H100-80GB and DGX-A100-40GB, respectively (top-most figure and bottom figure here).

![a100-img](./images/a100-8.png)

## Ongoing Optimizations

* Reduce unnecessary D2H/H2D memcpy

* On-the-fly fp16 to fp32 casting for CPUAdam

## Tutorials

Examples and tutorials are available [here](https://github.com/microsoft/Megatron-DeepSpeed/blob/main/examples_deepspeed/offload_pp/README.md).

## Contributors:

This project was made possible by the contributions of the following people from the DeepSpeed team:

[Guanhua Wang](https://www.microsoft.com/en-us/research/people/guanhuawang/), Masahiro Tanaka, Xiaoxia Wu, Lok Chand Koppaka, Samyam Rajbhandari, [Olatunji Ruwase](https://www.microsoft.com/en-us/research/people/olruwase/), [Yuxiong He](https://www.microsoft.com/en-us/research/people/yuxhe/) (team lead)
Binary file added blogs/deepspeed-offloadpp/images/a100-8.png
Binary file added blogs/deepspeed-offloadpp/images/cpu-offload.png
Binary file added blogs/deepspeed-offloadpp/images/h100-8.png
Binary file added blogs/deepspeed-offloadpp/images/twin-offload.png
2 changes: 1 addition & 1 deletion deepspeed/runtime/activation_checkpointing/config.py
@@ -17,7 +17,7 @@
"partitioned_activations": [true|false],
"number_checkpoints": 100,
"contiguous_memory_optimization": [true|false],
"cpu_checkpointing": [true|false]
"cpu_checkpointing": [true|false],
"profile": [true|false],
"synchronize_checkpoint_boundary": [true|false],
}
4 changes: 4 additions & 0 deletions deepspeed/runtime/engine.py
@@ -700,6 +700,9 @@ def zero_cpu_offload(self):
return self._config.zero_config.offload_optimizer.device == OffloadDeviceEnum.cpu
return False

def zero_partial_offload(self):
return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0)

def zero_sub_group_size(self):
return self._config.zero_config.sub_group_size

@@ -1565,6 +1568,7 @@ def _configure_zero_optimizer(self, optimizer):
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
sub_group_size=self.zero_sub_group_size(),
offload_ratio=self.zero_partial_offload(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
9 changes: 8 additions & 1 deletion deepspeed/runtime/zero/config.py
@@ -6,7 +6,7 @@
import sys
from typing import Optional
from enum import Enum
from deepspeed.pydantic_v1 import Field, validator
from deepspeed.pydantic_v1 import Field, validator, root_validator
from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
from deepspeed.utils import logger
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
@@ -300,3 +300,10 @@ def overlap_comm_valid(cls, field_value, values):
assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'"
field_value = values["stage"] == ZeroStageEnum.weights
return field_value

@root_validator
def offload_ratio_check(cls, values):
offload_config = getattr(values, "offload_optimizer", {})
if offload_config and offload_config.ratio < 1.0:
assert values.get("stage") == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
return values
3 changes: 3 additions & 0 deletions deepspeed/runtime/zero/offload_config.py
@@ -92,3 +92,6 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
def set_pipeline(cls, field_value, values):
values["pipeline"] = field_value or values.get("pipeline", False)
return field_value

ratio: float = Field(1.0, ge=0.0, le=1.0)
""" Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""
60 changes: 52 additions & 8 deletions deepspeed/runtime/zero/stage3.py
@@ -20,7 +20,7 @@
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper
from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper
@@ -104,6 +104,7 @@ def __init__(
offload_optimizer_config=None,
offload_param_config=None,
sub_group_size=1000000000000,
offload_ratio=0.0,
mpu=None,
clip_grad=0.0,
gradient_accumulation_dtype=torch.float32,
@@ -159,6 +160,7 @@ def __init__(
self.offload_param_pin_memory = False
self.params_in_nvme_and_cpu = False
self.max_params_in_cpu = 0
self.partial_offload = offload_ratio

#num of ranks in a ZeRO param partitioning group
self.zero_hpz_partition_size = zero_hpz_partition_size
@@ -191,6 +193,23 @@ def __init__(
self.persistent_parameters = self.parameter_offload.persistent_parameters
self._configure_offloading(offload_optimizer_config, offload_param_config)

# backup fused_adam optimizer init
if self.offload_optimizer and self.partial_offload != 1.0:
backup_gpu_tensor = torch.randn(1, device='cuda').to(self.dtype)
backup_gpu_param = torch.nn.Parameter(backup_gpu_tensor)
assert type(init_optimizer) == DeepSpeedCPUAdam, 'Hybrid Optimizer Only Supports DeepSpeedCPUAdam'
self.backup_optimizer = FusedAdam([backup_gpu_param],
lr=self.optimizer.param_groups[0]["lr"],
bias_correction=self.optimizer.param_groups[0]["bias_correction"],
betas=self.optimizer.param_groups[0]["betas"],
eps=self.optimizer.param_groups[0]["eps"],
weight_decay=self.optimizer.param_groups[0]["weight_decay"],
amsgrad=self.optimizer.param_groups[0]["amsgrad"])
# Multiple param_groups configs for back-up optimizer
if len(self.optimizer.param_groups) > 1:
for i in range(1, len(self.optimizer.param_groups)):
self.backup_optimizer.add_param_group(self.optimizer.param_groups[i])

self.module = module
self.elastic_checkpoint = elastic_checkpoint

@@ -780,6 +799,17 @@ def _create_fp32_partitions(self):
nvme_fp32_dest_tensors = []
fp32_element_size = torch.tensor([], dtype=torch.float32).element_size()

# Assign portion of subgroup to cpu, the other to gpu.
if self.offload_optimizer:
self.subgroup_to_device = {}
sub_group_size = len(self.fp16_partitioned_groups_flat)
# print(f"Partial offload sub_group_size is {sub_group_size}, ratio is {self.partial_offload}\n")
for i in range(sub_group_size):
if i < int(self.partial_offload * sub_group_size):
self.subgroup_to_device[i] = 'cpu'
else:
self.subgroup_to_device[i] = get_accelerator()._name

for i, tensor in enumerate(self.fp16_partitioned_groups_flat):
num_elements = self.fp16_partitioned_groups_flat_numel[i]

@@ -816,8 +846,12 @@ def _create_fp32_partitions(self):
self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i)
self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer)
else:
self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
self.device).clone().float().detach())
if self.offload_optimizer:
self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
self.subgroup_to_device[i]).clone().float().detach())
else:
self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
self.device).clone().float().detach())

self.fp32_partitioned_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it

@@ -885,10 +919,20 @@ def _release_ipg_buffers(self):
def _optimizer_step(self, sub_group_id):
param_group_id = self.sub_group_to_group_id[sub_group_id]
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
self.optimizer.param_groups[param_group_id]['params'] = [fp32_param]

self.optimizer.step()
self.optimizer.param_groups[param_group_id]['params'] = []
if self.offload_optimizer:
cur_device = self.subgroup_to_device[sub_group_id]
if cur_device == 'cpu':
self.optimizer.param_groups[param_group_id]['params'] = [fp32_param]
cpu_loss = self.optimizer.step()
self.optimizer.param_groups[param_group_id]['params'] = []
else:
self.backup_optimizer.param_groups[param_group_id]['params'] = [fp32_param]
gpu_loss = self.backup_optimizer.step()
self.backup_optimizer.param_groups[param_group_id]['params'] = []
else:
self.optimizer.param_groups[param_group_id]['params'] = [fp32_param]
self.optimizer.step()
self.optimizer.param_groups[param_group_id]['params'] = []

def _swappable_optimizer_subgroup(self, sub_group_id):
if not self.swap_optimizer:
@@ -955,7 +999,7 @@ def initialize_optimizer_states(self):
if self.offload_optimizer_pin_memory:
subgroup_gradient_buffer = get_accelerator().pin_memory(subgroup_gradient_buffer)

self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer
self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer.to(self.subgroup_to_device[i])
else:
self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements)
