[Shardformer] Merge flash attention branch to pipeline branch (hpcaitech#4362)

* [shardformer] supported flash attention test dependency (hpcaitech#4158)

* [shardformer] fix flash attention utils test (hpcaitech#4180)

* [shardformer] opt support flash attention (hpcaitech#4163)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] add performance benchmark of shardformer (hpcaitech#4175)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] benchmark fix

* [shardformer] benchmark fix

* [shardformer] llama support flash attention (hpcaitech#4185)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] llama support flash attention

* [shardformer] llama support flash attention

* [shardformer] Move the import statement for xformer outside the forward function.

* [shardformer] gpt2 support flash attention. (hpcaitech#4191)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] gpt2 support flash attention

* [shardformer] gpt2 support flash attention

* [shardformer] gpt2 support flash attention

* [shardformer] bloom support flash attention (hpcaitech#4188)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bloom support flash attention

* [shardformer] add assert to sequence length

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

* [shardformer] bert support flash attention. (hpcaitech#4206)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bert support flash attention

* [shardformer] t5 support flash attention. (hpcaitech#4216)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] t5 support flash attention

* [shardformer] t5 support flash attention

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* [shardformer] support 'paddedcausal' type of attention mask in ColoAttention. (hpcaitech#4215)

* added padded causal attn mask type for ColoAttention

* [shardformer]t5 flash attention fix (hpcaitech#4239)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] t5 flash attention fix

* [shardformer] update gpt2 to use coloattention. (hpcaitech#4234)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2

* [shardformer] update opt and llama to use coloattention. (hpcaitech#4226)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt

* [shardformer] shardformer support jit fused operator. (hpcaitech#4236)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bloom support jit fused operator

* [shardformer] bloom support jit fused operator

* [shardformer] bloom support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] add roadmap of flash attention

* [shardformer] add roadmap of flash attention

* [shardformer] add roadmap of flash attention

* [shardformer] add type hint to 'self' param of forward

* [shardformer] merge feature/shardformer-models branch to feature/flash-attention-shardformer branch. (hpcaitech#4290)

* Feature/vit support (hpcaitech#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (hpcaitech#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (hpcaitech#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (hpcaitech#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

---------

Co-authored-by: Kun Lin <[email protected]>
Co-authored-by: FoolPlayer <[email protected]>

* [shardformer] whisper support flash attention (hpcaitech#4301)

* Feature/vit support (hpcaitech#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (hpcaitech#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (hpcaitech#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (hpcaitech#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] whisper support flash attention

* [shardformer] whisper support flash attention

* [shardformer]whisper support jit operator

---------

Co-authored-by: Kun Lin <[email protected]>
Co-authored-by: FoolPlayer <[email protected]>

* [shardformer] sam support flash attention (hpcaitech#4316)

* Feature/vit support (hpcaitech#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (hpcaitech#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (hpcaitech#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (hpcaitech#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] sam support flash attention

---------

Co-authored-by: Kun Lin <[email protected]>
Co-authored-by: FoolPlayer <[email protected]>

* [shardformer] merge blip2/chatglm  (hpcaitech#4321)

* Feature/vit support (hpcaitech#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (hpcaitech#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (hpcaitech#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (hpcaitech#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (hpcaitech#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

---------

Co-authored-by: Kun Lin <[email protected]>
Co-authored-by: FoolPlayer <[email protected]>
Co-authored-by: klhhhhh <[email protected]>

* [shardformer] blip2 support flash attention and jit operator (hpcaitech#4325)

* Feature/vit support (hpcaitech#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (hpcaitech#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (hpcaitech#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (hpcaitech#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (hpcaitech#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

* [shardformer] blip2 support flash attention and jit operator

* [shardformer] blip2 support flash attention and jit operator

* [shardformer] blip2 support flash attention and jit operator

---------

Co-authored-by: Kun Lin <[email protected]>
Co-authored-by: FoolPlayer <[email protected]>
Co-authored-by: klhhhhh <[email protected]>

* [shardformer] chatglm support flash attention and jit operator (hpcaitech#4330)

* Feature/vit support (hpcaitech#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (hpcaitech#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (hpcaitech#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (hpcaitech#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (hpcaitech#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

---------

Co-authored-by: Kun Lin <[email protected]>
Co-authored-by: FoolPlayer <[email protected]>
Co-authored-by: klhhhhh <[email protected]>

* [shardformer] vit support flash attention and jit operator (hpcaitech#4334)

* Feature/vit support (hpcaitech#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (hpcaitech#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (hpcaitech#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (hpcaitech#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (hpcaitech#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

* [shardformer] vit support flash attention and jit operator

* [shardformer] vit support flash attention and jit operator

---------

Co-authored-by: Kun Lin <[email protected]>
Co-authored-by: FoolPlayer <[email protected]>
Co-authored-by: klhhhhh <[email protected]>

* [pipeline] merge flash attention branch

* [pipeline] merge flash attention branch

* [pipeline] merge flash attention branch

* [pipeline] fix conflict

* [pipeline] fix conflict

* Merge branch 'feature/pipeline' into feature/pipeline

* Merge branch 'feature/pipeline' into feature/pipeline

* Merge branch 'feature/pipeline' into feature/pipeline

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* fix flash attention tests

* gemini ignore whisper

* fix vit

* fix xformers import handle

---------

Co-authored-by: Frank Lee <[email protected]>
Co-authored-by: Kun Lin <[email protected]>
Co-authored-by: FoolPlayer <[email protected]>
Co-authored-by: klhhhhh <[email protected]>
5 people authored and ver217 committed Aug 15, 2023
1 parent 0e3a8da commit 4711174
Showing 52 changed files with 2,061 additions and 1,556 deletions.
58 changes: 55 additions & 3 deletions colossalai/shardformer/README.md
@@ -30,7 +30,7 @@

### Quick Start

The sample API usage is given below:
The sample API usage is given below (if you want to enable flash attention, please install xformers first):

```python
from colossalai.shardformer import ShardConfig, Shard
@@ -106,6 +106,20 @@ We will follow this roadmap to develop Shardformer:
- [ ] Multi-modal
- [x] SAM
- [x] BLIP-2
- [ ] Flash Attention Support
- [ ] NLP
- [x] BERT
- [x] T5
- [x] LlaMa
- [x] GPT2
- [x] OPT
- [x] BLOOM
- [ ] GLM
- [ ] RoBERTa
- [ ] ALBERT
- [ ] ERNIE
- [ ] GPT Neo
- [ ] GPT-J

## 💡 API Design

@@ -373,11 +387,49 @@ pytest tests/test_shardformer

### System Performance

To be added.
We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate the performance improvement of Shardformer. We compared the training time of the original model (`org_model`) with that of the sharded model (`shard_model`).

We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.

In the case of using 2 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
| :------: | :-----: | :-----: |
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |


<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/performance_benchmark_gpus2.png" width="600" />
<br/>
</p>

In the case of using 4 GPUs, the training times are as follows.

| N_CTX | org_model | shard_model |
| :------: | :-----: | :-----: |
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |



<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/performance_benchmark_gpus4.png" width="600" />
<br/>
</p>


As shown in the figures above, once the sequence length reaches roughly 1000 or more, the benefit of Shardformer's parallel optimization for long sequences becomes evident.
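
For reference, the sharded model in these benchmarks is produced roughly as sketched below. This is a hedged summary of `examples/performance_benchmark.py` from this commit; the `enable_flash_attention` and `enable_jit_fused` flags are assumed to be the option names this PR adds to `ShardConfig`.

```python
import transformers

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer

# Launch the distributed context first, e.g.:
#   torchrun --standalone --nproc_per_node=2 benchmark_sketch.py
colossalai.launch_from_torch({})

# Small Llama config, matching the benchmark settings above.
config = transformers.LlamaConfig(num_hidden_layers=4,
                                  hidden_size=128,
                                  intermediate_size=256,
                                  num_attention_heads=4,
                                  max_position_embeddings=128,
                                  num_labels=16)
model = transformers.LlamaForSequenceClassification(config)

# The two flash/jit flags below are assumptions based on the options introduced in this commit.
shard_config = ShardConfig(enable_tensor_parallelism=True,
                           enable_fused_normalization=True,
                           enable_flash_attention=True,
                           enable_jit_fused=True)
sharded_model = ShardFormer(shard_config=shard_config).optimize(model).cuda()
```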

### Convergence

To validate that training the model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both the shardformer and non-shardformer approaches and compared the accuracy, loss, and F1 score of the training results.

To validate that training the model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both the shardformer and non-shardformer approaches and compared the accuracy, loss, and F1 score of the training results.

| accuracy | f1 | loss | GPU number | model shard |
| :------: | :-----: | :-----: | :--------: | :---------: |
@@ -1,4 +1,4 @@
torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \
torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
--model "bert" \
--pretrain "bert-base-uncased" \
--max_epochs 1 \
86 changes: 86 additions & 0 deletions colossalai/shardformer/examples/performance_benchmark.py
@@ -0,0 +1,86 @@
"""
Shardformer Benchmark
"""
import torch
import torch.distributed as dist
import transformers
import triton

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer


def data_gen(batch_size, seq_length):
input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long)
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_sequence_classification(batch_size, seq_length):
    # sequence-classification data gen: each sample gets a single integer class label (no padding needed)
data = data_gen(batch_size, seq_length)
data['labels'] = torch.ones((batch_size), dtype=torch.long)
return data


MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128,
num_labels=16)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)

# vary seq length for fixed head and batch=4
configs = [
triton.testing.Benchmark(x_names=['N_CTX'],
x_vals=[2**i for i in range(8, 13)],
line_arg='provider',
line_vals=['org_model', 'shard_model'],
line_names=['org_model', 'shard_model'],
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
                             plot_name=f'llama_for_sequence_classification-batch-{BATCH}',
args={
'BATCH': BATCH,
'dtype': torch.float16,
'model_func': model_func
})
]


def train(model, data):
output = model(**data)
loss = output.logits.mean()
loss.backward()


@triton.testing.perf_report(configs)
def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"):
warmup = 10
rep = 100
# prepare data
data = data_gen_for_sequence_classification(BATCH, N_CTX)
data = {k: v.cuda() for k, v in data.items()}
model = model_func().to(device)
model.train()
if provider == "org_model":
fn = lambda: train(model, data)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "shard_model":
shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model).cuda()
fn = lambda: train(sharded_model, data)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms


# start benchmark, command:
# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
colossalai.launch_from_torch({})
bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0)
138 changes: 137 additions & 1 deletion colossalai/shardformer/modeling/bert.py
@@ -1,5 +1,6 @@
import math
import warnings
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -962,3 +963,138 @@ def bert_for_question_answering_forward(
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}


def get_bert_flash_attention_forward():

try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.bert.modeling_bert import BertAttention

def forward(
self: BertAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

query_layer = self.transpose_for_scores(mixed_query_layer)

use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

final_attention_mask = None
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r

positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
final_attention_mask = relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
final_attention_mask = relative_position_scores_query + relative_position_scores_key

scale = 1 / math.sqrt(self.attention_head_size)
if attention_mask is not None:
            if final_attention_mask is not None:
final_attention_mask = final_attention_mask * scale + attention_mask
else:
final_attention_mask = attention_mask
batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
tgt_len = key_layer.size()[2]
final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len)

query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
value_layer = value_layer.permute(0, 2, 1, 3).contiguous()

context_layer = me_attention(query_layer,
key_layer,
value_layer,
attn_bias=final_attention_mask,
p=self.dropout.p,
scale=scale)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, None)

if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs

return forward
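
As a hedged illustration of how a factory like `get_bert_flash_attention_forward` is consumed (this is not the policy code from this PR), the returned function takes `self` as its first argument, so it can be bound directly onto the Hugging Face module that owns the query/key/value projections:

```python
from types import MethodType

from transformers import BertModel

from colossalai.shardformer.modeling.bert import get_bert_flash_attention_forward

# Illustrative wiring only; in ShardFormer the replacement is applied by the policies.
# Requires xformers to be installed.
model = BertModel.from_pretrained("bert-base-uncased")
flash_forward = get_bert_flash_attention_forward()
for layer in model.encoder.layer:
    self_attn = layer.attention.self    # the module holding query/key/value
    self_attn.forward = MethodType(flash_forward, self_attn)
```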


def get_jit_fused_bert_self_output_forward():

from transformers.models.bert.modeling_bert import BertSelfOutput

def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states

return forward


def get_jit_fused_bert_output_forward():

from transformers.models.bert.modeling_bert import BertOutput

def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states

return forward
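
The two JIT helpers above rely on a `dropout_add` method that is attached to the module elsewhere; as a rough, hedged stand-in for its semantics (the real fused kernel ships with ColossalAI and is not reproduced here):

```python
import torch
import torch.nn.functional as F


# Simplified stand-in showing only the semantics of the fused dropout + residual add.
@torch.jit.script
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    out = F.dropout(x, p=prob, training=training)
    return residual + out
```
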
60 changes: 60 additions & 0 deletions colossalai/shardformer/modeling/blip2.py
@@ -1,3 +1,4 @@
import math
from typing import Optional, Tuple, Union

import torch
@@ -58,3 +59,62 @@ def forward(
return outputs

return forward


def get_blip2_flash_attention_forward():

from transformers.models.blip_2.modeling_blip_2 import Blip2Attention

from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention

def forward(
self: Blip2Attention,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

bsz, tgt_len, embed_dim = hidden_states.size()
mixed_qkv = self.qkv(hidden_states)
mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]

attention = ColoAttention(embed_dim=self.embed_dim,
num_heads=self.num_heads,
dropout=self.dropout.p,
scale=self.scale)
context_layer = attention(query_states, key_states, value_states)

output = self.projection(context_layer)
outputs = (output, None)

return outputs

return forward
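
For context, the `ColoAttention` wrapper imported above can also be exercised on its own. A minimal sketch, assuming a CUDA device, an xformers installation, and q/k/v laid out as `(batch, seq_len, num_heads, head_dim)` exactly as in the Blip2 forward:

```python
import torch

from colossalai.kernel.cuda_native.flash_attention import ColoAttention

# Shapes and constructor arguments mirror the Blip2 forward above.
batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

attention = ColoAttention(embed_dim=num_heads * head_dim,
                          num_heads=num_heads,
                          dropout=0.0,
                          scale=head_dim**-0.5)
out = attention(q, k, v)    # expected shape: (batch, seq_len, embed_dim)
print(out.shape)
```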


def get_jit_fused_blip2_QFormer_self_output_forward():

from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput

def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states

return forward


def get_jit_fused_blip2_QFormer_output_forward():

from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput

def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states

return forward
