[shardformer] add performance benchmark of shardformer #4175

Merged
42 changes: 40 additions & 2 deletions colossalai/shardformer/README.md
@@ -377,11 +377,49 @@ pytest tests/test_shardformer

### System Performance

To be added.
We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate the performance improvement of Shardformer, comparing the training time of the original model with that of the sharded model.

We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. `N_CTX` denotes the sequence length.
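These settings mirror the constants defined in the benchmark script (the `SEQ_LENGTHS` name below is only for illustration; the script passes the sweep as `x_vals` to `triton.testing.Benchmark`):

```python
# key constants from performance_benchmark.py
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64   # batch size, attention heads, max sequence length, head dimension
SEQ_LENGTHS = [2**i for i in range(8, 13)]        # the benchmark sweeps N_CTX over 256, 512, 1024, 2048, 4096
```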

With 2 GPUs, the time per training step is as follows.
| N_CTX | org_model | shard_model |
| :------: | :-----: | :-----: |
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |


<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/performance_benchmark_gpus2.png" width="600" />
<br/>
</p>

With 4 GPUs, the time per training step is as follows.

| N_CTX | org_model | shard_model |
| :------: | :-----: | :-----: |
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |



<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/performance_benchmark_gpus4.png" width="600" />
<br/>
</p>


As the figures above show, once the sequence length reaches roughly 1024 or more, Shardformer's parallel optimization for long sequences becomes clearly beneficial.
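
For reference, the sharding step used in the benchmark boils down to roughly the following sketch (`model` here stands for the Hugging Face model being benchmarked; see [performance_benchmark.py](./examples/performance_benchmark.py) for the complete script):

```python
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch({})

# build the sharded variant ("shard_model" in the tables above) with fused normalization and tensor parallelism
shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model).cuda()
```

The benchmark itself is launched with `torchrun --standalone --nproc_per_node=2 performance_benchmark.py` (or `--nproc_per_node=4` for the 4-GPU numbers).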

### Convergence

To validate that training a model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both the Shardformer and non-Shardformer approaches and compared the accuracy, loss, and F1 score of the training results.

To validate that training a model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both the Shardformer and non-Shardformer approaches and compared the accuracy, loss, and F1 score of the training results.

| accuracy | f1 | loss | GPU number | model shard |
| :------: | :-----: | :-----: | :--------: | :---------: |
@@ -1,4 +1,4 @@
torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \
torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
--model "bert" \
--pretrain "bert-base-uncased" \
--max_epochs 1 \
86 changes: 86 additions & 0 deletions colossalai/shardformer/examples/performance_benchmark.py
@@ -0,0 +1,86 @@
"""
Shardformer Benchmark
"""
import torch
import torch.distributed as dist
import transformers
import triton

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer


def data_gen(batch_size, seq_length):
    input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_sequence_classification(batch_size, seq_length):
    # sequence classification data gen: reuse the LM-style inputs and add one class label per sample
    data = data_gen(batch_size, seq_length)
    data['labels'] = torch.ones((batch_size), dtype=torch.long)
    return data


MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
                                        hidden_size=128,
                                        intermediate_size=256,
                                        num_attention_heads=4,
                                        max_position_embeddings=128,
                                        num_labels=16)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)

# vary sequence length with a fixed number of heads and batch size = 4
configs = [
    triton.testing.Benchmark(x_names=['N_CTX'],
                             x_vals=[2**i for i in range(8, 13)],
                             line_arg='provider',
                             line_vals=['org_model', 'shard_model'],
                             line_names=['org_model', 'shard_model'],
                             styles=[('red', '-'), ('blue', '-')],
                             ylabel='ms',
                             plot_name=f'llama_for_sequence_classification-batch-{BATCH}',
                             args={
                                 'BATCH': BATCH,
                                 'dtype': torch.float16,
                                 'model_func': model_func
                             })
]


def train(model, data):
    # a single forward + backward pass on a dummy mean-of-logits loss
    output = model(**data)
    loss = output.logits.mean()
    loss.backward()


@triton.testing.perf_report(configs)
def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"):
    warmup = 10
    rep = 100
    # prepare data
    data = data_gen_for_sequence_classification(BATCH, N_CTX)
    data = {k: v.cuda() for k, v in data.items()}
    model = model_func().to(device)
    model.train()
    if provider == "org_model":
        fn = lambda: train(model, data)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms
    if provider == "shard_model":
        shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
        shard_former = ShardFormer(shard_config=shard_config)
        sharded_model = shard_former.optimize(model).cuda()
        fn = lambda: train(sharded_model, data)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms


# start benchmark, command:
# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
    colossalai.launch_from_torch({})
    bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0)