[doc] update sp doc (#6055)
* update sp doc

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
flybird11111 and pre-commit-ci[bot] authored Sep 11, 2024
1 parent 13946c4 commit a35a078
Showing 4 changed files with 350 additions and 0 deletions.
19 changes: 19 additions & 0 deletions docs/source/en/concepts/paradigms_of_parallelism.md
@@ -87,6 +87,24 @@ Related paper:
- [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
- [Chimera: Efficiently Training Large-Scale Neural Networks with Bidirectional Pipelines](https://arxiv.org/abs/2107.06925)

### Sequence Parallelism
Sequence parallelism is a parallel strategy that partitions along the sequence dimension, making it an effective method for training long text sequences. Mature sequence parallelism methods include Megatron’s sequence parallelism, DeepSpeed-Ulysses sequence parallelism, and ring-attention sequence parallelism.

#### Megatron SP:
This sequence parallelism method is implemented on top of tensor parallelism. On each GPU of the model-parallel group, the samples are independent and replicated. For the parts that cannot use tensor parallelism, such as non-linear operations like LayerNorm, the data is split along the sequence dimension so that each GPU computes only a portion of it. The linear parts, such as attention and the MLP, then use tensor parallelism, which requires the activations to be gathered. This further reduces activation memory usage when the model is partitioned. Note that this sequence parallelism method can only be used in conjunction with tensor parallelism.
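
To make the split/gather pattern concrete, below is a minimal PyTorch sketch of one transformer sub-block under this scheme, assuming a process group `tp_group` of size `P` and activations laid out as `[seq, batch, hidden]`; the function and tensor names are illustrative, not the Megatron or Colossal-AI implementation.

```python
import torch
import torch.distributed as dist

def megatron_sp_block(x, layernorm, attention, tp_group):
    # x: [seq_len / P, batch, hidden] -- activations already split along the sequence dim.
    P = dist.get_world_size(tp_group)

    # Non-linear part (LayerNorm) runs independently on the local sequence shard.
    x = layernorm(x)

    # All-gather along the sequence dim so the tensor-parallel linear layers
    # (attention / MLP) see the full sequence.
    gathered = [torch.empty_like(x) for _ in range(P)]
    dist.all_gather(gathered, x, group=tp_group)
    full_seq = torch.cat(gathered, dim=0)  # [seq_len, batch, hidden]

    # `attention` stands for a tensor-parallel module whose output is a partial
    # sum that still has to be reduced across the tp group.
    partial = attention(full_seq)

    # Reduce-scatter both sums the partial results and re-splits the sequence dim.
    shard = torch.empty_like(x)
    dist.reduce_scatter(shard, [c.contiguous() for c in partial.chunk(P, dim=0)], group=tp_group)
    return shard  # [seq_len / P, batch, hidden]
```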

#### DeepSpeed-Ulysses:
In this sequence parallelism, samples are split along the sequence dimension and the all-to-all communication operation is used, so that each GPU receives the full sequence but computes only a non-overlapping subset of the attention heads. This method supports fully general attention, both dense and sparse.
all-to-all is a full-exchange operation, similar to a distributed transpose. Before the attention computation, samples are split along the sequence dimension, so each device holds only a sequence of length N/P; after the all-to-all, the shape of each q, k, v sub-part becomes [N, d/P], so the whole sequence is still covered by the attention computation.
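
Below is a minimal sketch of this all-to-all redistribution, assuming activations shaped `[N/P, num_heads, head_dim]` on each rank and a process group `sp_group` of size `P`; the helper name and layout are illustrative, not DeepSpeed's implementation.

```python
import torch
import torch.distributed as dist

def ulysses_all_to_all(x, sp_group):
    """Redistribute from sequence shards to head shards.

    x:       [N / P, num_heads, head_dim] -- local sequence shard, all heads.
    returns: [N, num_heads / P, head_dim] -- full sequence, local subset of heads.
    Assumes num_heads is divisible by the sequence-parallel size P.
    """
    P = dist.get_world_size(sp_group)
    n_local, num_heads, head_dim = x.shape

    # Group the heads into P chunks and move that chunk axis to the front so a
    # single all-to-all scatters heads while gathering sequence positions.
    x = x.reshape(n_local, P, num_heads // P, head_dim).permute(1, 0, 2, 3).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=sp_group)

    # [P, N/P, heads/P, head_dim] -> [N, heads/P, head_dim]
    return out.reshape(n_local * P, num_heads // P, head_dim)
```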

#### Ring Attention:
Ring attention is conceptually similar to flash attention: each GPU computes only a local attention block, and the blocks are finally reduced to obtain the full attention result. The input sequence is split into multiple chunks along the sequence dimension, with each chunk handled by a different GPU or processor. Ring attention employs a strategy called "ring communication", in which kv sub-blocks are passed between GPUs through p2p communication for iterative computation, enabling multi-GPU training on ultra-long texts. With this strategy, each processor exchanges information only with its predecessor and successor, forming a ring network, so intermediate results can be transmitted efficiently between processors without global synchronization, reducing communication overhead.
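
Below is a simplified sketch of the ring schedule, assuming all ranks in the default process group form the ring and each holds one `[n_local, num_heads, head_dim]` chunk of q, k, and v; it accumulates the plain softmax numerator and denominator for clarity, whereas the real algorithm uses a numerically stable log-sum-exp and fused kernels.

```python
import torch
import torch.distributed as dist

def ring_attention(q, k, v):
    """Illustrative ring attention over all ranks of the default process group.

    q, k, v: [n_local, num_heads, head_dim] -- this rank's chunk of the sequence.
    Each rank keeps its q block and rotates its (k, v) block around the ring,
    accumulating the softmax numerator and denominator block by block.
    """
    P = dist.get_world_size()
    rank = dist.get_rank()
    send_to, recv_from = (rank + 1) % P, (rank - 1) % P
    scale = q.shape[-1] ** -0.5

    numer = torch.zeros_like(q)
    denom = torch.zeros(*q.shape[:-1], 1, dtype=q.dtype, device=q.device)

    k_blk, v_blk = k, v
    for step in range(P):
        # Local blockwise attention against the current (k, v) block.
        # (The real algorithm tracks a running max / log-sum-exp for stability.)
        scores = torch.einsum("qhd,khd->hqk", q, k_blk) * scale
        weights = scores.exp()
        numer += torch.einsum("hqk,khd->qhd", weights, v_blk)
        denom += weights.sum(dim=-1).transpose(0, 1).unsqueeze(-1)

        if step < P - 1:
            # Rotate the (k, v) block to the next rank via p2p communication.
            k_next, v_next = torch.empty_like(k_blk), torch.empty_like(v_blk)
            reqs = dist.batch_isend_irecv([
                dist.P2POp(dist.isend, k_blk, send_to),
                dist.P2POp(dist.isend, v_blk, send_to),
                dist.P2POp(dist.irecv, k_next, recv_from),
                dist.P2POp(dist.irecv, v_next, recv_from),
            ])
            for req in reqs:
                req.wait()
            k_blk, v_blk = k_next, v_next

    return numer / denom
```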

Related paper:
- [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/pdf/2205.05198)
- [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://arxiv.org/abs/2309.14509)
- [Ring Attention with Blockwise Transformers for Near-Infinite Context](https://arxiv.org/pdf/2310.01889)


## Optimizer-Level Parallel

@@ -122,3 +140,4 @@ Related paper:
- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840)
- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857)
- [PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management](https://arxiv.org/abs/2108.05818)
<!-- doc-test-command: echo -->
156 changes: 156 additions & 0 deletions docs/source/en/features/sequence_parallelism.md
@@ -0,0 +1,156 @@
# Sequence Parallelism

Author: Mingyan Jiang

**Prerequisite Tutorials**
- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
- [Booster API](../basics/booster_api.md)
- [Shardformer](../features/shardformer.md)
- [Booster plugin](../basics/booster_plugins.md)

**Example Code**
- [Using Sequence Parallelism Strategy](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/llama/benchmark.py)

**Related Papers**
- [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/pdf/2205.05198)
- [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://arxiv.org/abs/2309.14509)
- [Ring Attention with Blockwise Transformers for Near-Infinite Context](https://arxiv.org/pdf/2310.01889)

## Quick Overview

In this tutorial, you will learn how to use sequence parallelism. In Colossal-AI, we have implemented several types of sequence parallelism: TP+SP, DeepSpeed-Ulysses, and ring attention. Below, we introduce how to use each of them.

## Table Of Contents

In this tutorial, we will cover the use of three sequence parallelism strategies:

1. Using TP+SP;
2. Using DeepSpeed-Ulysses;
3. Using ring attention.


## Implementation in Colossal-AI

In Colossal-AI, sequence parallelism is implemented via the shardformer and can be invoked through the `HybridParallelPlugin` and `MoeHybridParallelPlugin` interfaces. For more information about the plugins, refer to the [plugin usage documentation](../basics/booster_plugins.md).

### Using Sequence Parallelism with HybridParallelPlugin

The `HybridParallelPlugin` supports three types of sequence parallelism: TP+SP, DeepSpeed-Ulysses, and ring attention. You can refer to the parallel techniques introduction [document](../concepts/paradigms_of_parallelism.md) for more details. A full [example](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/llama/benchmark.py) of sequence parallelism with `HybridParallelPlugin` is available.

#### Defining Model Components

```python
import argparse

import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import HybridAdam

# initialize the distributed environment (the script is launched with torchrun)
colossalai.launch_from_torch()

config = LlamaConfig(max_position_embeddings=4096)

# define dataset
class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(
            0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
        )
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }

parser = argparse.ArgumentParser()
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
args = parser.parse_args()

model = AutoModelForCausalLM.from_config(
    config,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
optimizer = HybridAdam(model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
# usually, num_samples = args.batch_size * args.num_steps * dp_size
dataset = RandomDataset(num_samples=10000, max_length=args.max_length, vocab_size=config.vocab_size)
```
#### Using TP+SP
Define the plugin. When using this sequence parallelism mode, `sp_size` is set to match `tp_size`, and the sequence-parallel group overlaps with the tensor-parallel group.
```python
plugin = HybridParallelPlugin(
tp_size=4,
sp_size=1,
enable_all_optimization=True,
enable_sequence_parallelism=True,
sequence_parallelism_mode="split_gather",
)
```

#### Using DeepSpeed-Ulysses
Define the plugin. In DeepSpeed-Ulysses sequence parallelism, the tensor-parallel group and the sequence-parallel group are orthogonal.
```python
plugin = HybridParallelPlugin(
tp_size=2,
sp_size=2,
enable_all_optimization=True,
enable_sequence_parallelism=True,
sequence_parallelism_mode="all_to_all",
)
```

#### Using Ring Attention
Define the plugin. In ring attention sequence parallelism, the tensor-parallel group and the sequence-parallel group are orthogonal, and `sp_size` must be set to the desired parallel size.
```python
plugin = HybridParallelPlugin(
tp_size=2,
sp_size=2,
enable_all_optimization=True,
enable_sequence_parallelism=True,
sequence_parallelism_mode="ring_attn",
)
```
#### Using Booster
```python
booster = Booster(plugin=plugin)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
```

#### Training the Model
```python
for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=dist.get_rank() != 0)):
    outputs = model(**batch)
    loss = outputs[0]
    del outputs  # free memory

    if dist.get_rank() == dist.get_world_size() - 1:
        print(f"Step {step} loss: {loss}")
    booster.backward(loss, optimizer)
    optimizer.step()
    optimizer.zero_grad()
```
### Sequence Parallelism with MoeHybridParallelPlugin
Currently, the `MoeHybridParallelPlugin` only supports DeepSpeed-Ulysses sequence parallelism. Its usage is similar to that of `HybridParallelPlugin`. For a complete example, refer to this [example](https://github.com/hpcaitech/ColossalAI/blob/main/examples/language/deepseek/benchmark.py).
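
As a rough sketch, under the assumption that `MoeHybridParallelPlugin` takes the same sequence-parallel arguments as `HybridParallelPlugin` plus an expert-parallel size `ep_size`, the plugin definition might look like the following; check the linked benchmark for the exact signature.

```python
from colossalai.booster.plugin import MoeHybridParallelPlugin

# Illustrative only: `ep_size` (expert parallelism) and the sequence-parallel
# arguments are assumed to mirror HybridParallelPlugin; see the linked benchmark
# for the exact signature.
plugin = MoeHybridParallelPlugin(
    ep_size=2,
    tp_size=1,
    sp_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",  # only Ulysses-style SP is supported
)
```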



### Conclusion
Among the sequence parallelism methods above, ring attention places no constraint on the number of attention heads and can train ultra-long sequences, but because the attention computation is split into blocks, its throughput may be lower. TP+SP and DeepSpeed-Ulysses require the number of attention heads to be divisible by the size of the sequence-parallel group. All of these sequence parallelism methods are compatible with high-performance attention implementations such as flash attention. Sequence parallelism can also be used together with Gemini to train extremely large models, and it can be combined with TP, PP, and DP to form 4D parallelism, as sketched below.
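
As an illustration of such a combination, the sketch below configures tensor, pipeline, and sequence parallelism together, assuming at least eight GPUs; the sizes are placeholders and must divide the available number of GPUs.

```python
from colossalai.booster.plugin import HybridParallelPlugin

# Illustrative combination of TP, PP, and SP; the ranks left over after
# tp * pp * sp are used for data parallelism (exact group sizing depends on
# the plugin version).
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    sp_size=2,
    enable_all_optimization=True,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
)
```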

<!-- doc-test-command: torchrun --standalone --nproc_per_node=4 sequence_parallelism.py -->
20 changes: 20 additions & 0 deletions docs/source/zh-Hans/concepts/paradigms_of_parallelism.md
@@ -62,6 +62,25 @@
- [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
- [Chimera: Efficiently Training Large-Scale Neural Networks with Bidirectional Pipelines](https://arxiv.org/abs/2107.06925)

### Sequence Parallelism
Sequence parallelism is a parallel strategy that partitions along the sequence dimension, making it an effective method for training long text sequences. Mature sequence parallelism methods include Megatron's sequence parallelism, DeepSpeed-Ulysses sequence parallelism, and ring-attention sequence parallelism.

#### Megatron SP:

This sequence parallelism method is implemented on top of tensor parallelism. On each GPU of the model-parallel group, the samples are independent and replicated. For the parts that cannot use tensor parallelism, such as non-linear operations like LayerNorm, the data can be split along the sequence dimension so that each GPU computes only a portion of it. The linear parts, such as attention and the MLP, then use tensor parallelism, which requires the activations to be gathered. This further reduces activation memory usage when the model is partitioned. Note that this sequence parallelism method can only be used together with tensor parallelism.

#### DeepSpeed-Ulysses:

In this sequence parallelism, samples are split along the sequence dimension and the all-to-all communication operation is used, so that each GPU receives the full sequence but computes only a non-overlapping subset of the attention heads. This method supports fully general attention, both dense and sparse.
all-to-all is a full-exchange operation, similar to a distributed transpose. Before the attention computation, samples are split along the sequence dimension, so each device holds only a sequence of length N/P; after the all-to-all, the shape of each q, k, v sub-part becomes [N, d/P], so the whole sequence is still covered by the attention computation.

#### Ring Attention:

Ring attention is conceptually similar to flash attention: each GPU computes only a local attention block, and the blocks are finally reduced to obtain the full attention result. The input sequence is split into multiple chunks along the sequence dimension, with each chunk handled by a different GPU or processor. Ring attention uses a strategy called "ring communication", in which kv sub-blocks are passed between GPUs via p2p communication for iterative computation, enabling multi-GPU training on ultra-long texts. With this strategy, each processor exchanges information only with its predecessor and successor, forming a ring network, so intermediate results can be passed between processors efficiently without global synchronization, reducing communication overhead.

Related papers:
- [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/pdf/2205.05198)
- [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://arxiv.org/abs/2309.14509)
- [Ring Attention with Blockwise Transformers for Near-Infinite Context](https://arxiv.org/pdf/2310.01889)


## Optimizer-Level Parallel

@@ -90,3 +109,4 @@
- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840)
- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857)
- [PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management](https://arxiv.org/abs/2108.05818)
<!-- doc-test-command: echo -->
