Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

weight parallel #773

Open
wants to merge 23 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 36 additions & 17 deletions ppdiffusers/deploy/sd3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@ python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compati
# 安装develop版本的paddle,请根据自己的cuda版本选择对应的paddle版本,这里选择12.3的cuda版本
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/

# 安装paddlemix库,使用集成在paddlemix库中的自定义算子。
python -m pip install paddlemix

# 指定 libCutlassGemmEpilogue.so 的路径
# 详情请参考 https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/README.md
export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/build:$LD_LIBRARY_PATH
- 请注意,该项用于在静态图推理时利用Cutlass融合算子提升推理性能,但是并不是必须项。
如果不使用Cutlass可以将`./text_to_image_generation-stable_diffusion_3.py`中的`exp_enable_use_cutlass`设为False。
-
```

高性能推理指令:
Expand All @@ -23,6 +29,8 @@ python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height
--num-inference-steps 50 --inference_optimize 1 \
--benchmark 1
```
注:--inference_optimize 1 用于开启推理优化,--benchmark 1 用于开启性能测试。


- 在 NVIDIA A100-SXM4-40GB 上测试的性能如下:

Expand All @@ -31,28 +39,39 @@ python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height
| 1.2 s | 1.78 s | 4.202 s |


## Paddle Stable Diffusion 3 模型多卡推理:
### batch parallel 实现原理
- 在SD3中,对于输入是一个prompt时,使用CFG需要同时进行unconditional guide和text guide的生成,此时 MM-DiT-blocks 的输入batch_size=2;
所以我们考虑在多卡并行的方案中,将batch为2的输入拆分到两张卡上进行计算,这样单卡的计算量就减少为原来的一半,降低了单卡所承载的浮点计算量。
计算完成后,我们再把两张卡的计算结果 聚合在一起,结果与单卡计算完全一致。
### 开启多卡推理方法
- Paddle Inference 提供了SD3模型的多卡推理功能,用户可以通过设置 `--inference_optimize_bp 1` 来开启这一功能,
使用 `python -m paddle.distributed.launch --gpus 0,1` 指定使用哪些卡进行推理。
高性能多卡推理指令:
## Paddle Stable Diffusion 3 模型多卡推理:
### Data Parallel 实现原理
- 在SD3中,对于输入是一个prompt时,使用CFG需要同时进行unconditional guide和text guide的生成,此时 MM-DiT-blocks 的输入batch_size=2;
所以我们考虑在多卡并行的方案中,将batch为2的输入拆分到两张卡上进行计算,这样单卡的计算量就减少为原来的一半,降低了单卡所承载的浮点计算量。
计算完成后,我们再把两张卡的计算结果聚合在一起,结果与单卡计算完全一致。

### Model parallel 实现原理
- 在SD3中,在Linear和Attnetion中有大量的GEMM(General Matrix Multiply),当生成高分辨率图像时,GEMM的计算量以及模型的预训练权重大小都呈线性递增。
因此,我们考虑在多卡并行方案中,将模型的这些GEMM拆分到两张卡上进行计算,这样单卡的计算量和权重大小就都减少为原来的一半,不仅降低了单卡所承载的浮点计算量,也降低了单卡的显存占用。

### 开启多卡推理方法
- Paddle Inference 提供了SD3模型的多卡推理功能,用户可以通过设置 `mp_size 2` 来开启Model Parallel,使用 `dp_size 2`来开启Data Parallel。
使用 `python -m paddle.distributed.launch --gpus “0,1,2,3”` 指定使用哪些卡进行推理,其中`--gpus “0,1,2,3”`即为启用的GPU卡号。
如果只需使用两卡推理,则只需指定两卡即可,如 `python -m paddle.distributed.launch --gpus “0,1”`。同时需要指定使用的并行方法及并行度,如 `mp_size 2` 或者 `dp_size 2`。

- 注意,这里的`mp_size`需要设定为不大于输入的batch_size个,且`mp_size`和`dp_size`的和不能超过机器总卡数。
- 高性能多卡推理指令:
```shell
# 执行多卡推理指令
python -m paddle.distributed.launch --gpus 0,1 text_to_image_generation-stable_diffusion_3.py \
python -m paddle.distributed.launch --gpus "0,1,2,3" text_to_image_generation-stable_diffusion_3.py \
--dtype float16 \
--height 512 --width 512 \
--num-inference-steps 50 \
--height 1024 \
--width 1024 \
--num-inference-steps 20 \
--inference_optimize 1 \
--inference_optimize_bp 1 \
--mp_size 2 \
--dp_size 2 \
--benchmark 1
```
## 在 NVIDIA A800-SXM4-80GB 上测试的性能如下:
注:--inference_optimize 1 用于开启推理优化,--benchmark 1 用于开启性能测试。

## 在 NVIDIA A800-SXM4-80GB 上测试的性能如下:

| Paddle batch parallel | Paddle Single Card | PyTorch | TensorRT | Paddle 动态图 |
| --------------------- | ------------------ | --------- | -------- | ------------ |
| 0.86 s | 1.2 s | 1.78 s | 1.16 s | 4.202 s |​⬤
| Paddle mp_size=2 & dp_size=2 | Paddle mp_size=2 | Paddle dp_size=2 | Paddle Single Card | Paddle 动态图 |
| ---------------------------- | ------------------- | ---------------- | ------------------ | ------------ |
| 0.99s | 1.581 s | 1.319 s | 2.376 s | 3.2 s |
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import os

import paddle


def parse_args():
parser = argparse.ArgumentParser(
description=" Use PaddleMIX to accelerate the Stable Diffusion3 image generation model."
Expand All @@ -40,6 +43,12 @@ def parse_args():
parser.add_argument("--width", type=int, default=512, help="Width of the generated image.")
parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of inference steps.")
parser.add_argument("--dtype", type=str, default="float32", help="Inference data types.")
parser.add_argument(
"--mp_size", type=int, default=1, help="This size refers to the degree of parallelism using model parallel."
)
parser.add_argument(
"--dp_size", type=int, default=1, help="This size refers to the degree of parallelism using data parallel."
)

return parser.parse_args()

Expand All @@ -49,49 +58,42 @@ def parse_args():
if args.inference_optimize:
os.environ["INFERENCE_OPTIMIZE"] = "True"
os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
if args.inference_optimize_bp:
os.environ["INFERENCE_OPTIMIZE_BP"] = "True"
if args.dtype == "float32":
inference_dtype = paddle.float32
elif args.dtype == "float16":
inference_dtype = paddle.float16


if args.inference_optimize_bp:
from paddle.distributed import fleet
from paddle.distributed.fleet.utils import recompute
import numpy as np
import random
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
model_parallel_size = 2
data_parallel_size = 1
strategy.hybrid_configs = {
"dp_degree": data_parallel_size,
"mp_degree": model_parallel_size,
"pp_degree": 1
}
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()
mp_id = hcg.get_model_parallel_rank()
rank_id = dist.get_rank()
import paddle.distributed as dist
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
model_parallel_size = args.mp_size
data_parallel_size = args.dp_size
strategy.hybrid_configs = {"dp_degree": data_parallel_size, "mp_degree": model_parallel_size, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()
mp_id = hcg.get_model_parallel_rank()
dp_id = hcg.get_data_parallel_rank()
rank_id = dist.get_rank()
mp_degree = hcg.get_model_parallel_world_size()
dp_degree = hcg.get_data_parallel_world_size()

os.environ["TRITON_KERNEL_CACHE_DIR"] = f"./tmp/sd3_parallel/{rank_id}"

import datetime
from ppdiffusers import StableDiffusion3Pipeline

from ppdiffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers",
paddle_dtype=inference_dtype,
)

pipe.transformer = paddle.incubate.jit.inference(
pipe.transformer,
save_model_dir="./tmp/sd3",
enable_new_ir=True,
cache_static_model=True,
# V100环境下,需设置exp_enable_use_cutlass=False,
save_model_dir="./tmp/1024_TP_sd3_parallel",
enable_new_ir=False,
cache_static_model=False,
exp_enable_use_cutlass=True,
delete_pass_lists=["add_norm_fuse_pass"],
)
Expand Down Expand Up @@ -138,8 +140,7 @@ def parse_args():
cuda_mem_after_used = paddle.device.cuda.max_memory_allocated() / (1024**3)
print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")

if args.inference_optimize_bp:
if rank_id == 0:
if dp_degree > 1 or mp_degree > 1:
image.save("text_to_image_generation-stable_diffusion_3-result.png")
else:
image.save("text_to_image_generation-stable_diffusion_3-result.png")
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.

import paddle

from ppdiffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", paddle_dtype=paddle.float16
)
generator = paddle.Generator().manual_seed(42)
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, generator=generator).images[0]
image.save("text_to_image_generation-stable_diffusion_3-result.png")
image.save("text_to_image_generation-stable_diffusion_3-result.png")
114 changes: 85 additions & 29 deletions ppdiffusers/ppdiffusers/models/simplified_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,71 @@
# limitations under the License.

import paddle
import paddle.distributed.fleet as fleet
import paddle.nn.functional as F
from paddle import nn
from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear as CPLinear
from paddle.distributed.fleet.meta_parallel import RowParallelLinear as RPLinear
from paddle.nn import LayerList as LayerList

mp_degree = fleet.get_hybrid_communicate_group().get_model_parallel_world_size()


class SimplifiedSD3(nn.Layer):
def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
super().__init__()
self.num_layers = num_layers
self.dim = dim
self.head_dim = 64

self.silu = nn.Silu()
self.linear1 = nn.LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers)])
self.linear_context = nn.LayerList(
self.linear1 = LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers)])
self.linear_context = LayerList(
[nn.Linear(self.dim, (6 if i < num_layers - 1 else 2) * self.dim) for i in range(num_layers)]
)

self.norm_last_context = nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=True)

self.qkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
self.eqkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
self.to_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
self.to_add_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers - 1)])
self.ffn1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)])
self.ffn2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)])
self.ffn1_context = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)])
self.ffn2_context = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)])
if mp_degree > 1:
self.qkv_mp = LayerList(
[CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)]
)
self.eqkv_mp = LayerList(
[CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)]
)
self.to_out_linear_mp = LayerList(
[RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)]
)
# When using Model Parallel, for the symmetry of GEMM, we change num_layers-1 here to num_layers, which has no effect on the results.
self.to_add_out_linear_mp = LayerList(
[RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)]
)

self.ffn1_mp = LayerList(
[CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)]
)
self.ffn2_mp = LayerList(
[RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)]
)
self.ffn1_context_mp = LayerList(
[CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers - 1)]
)
self.ffn2_context_mp = LayerList(
[
RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True)
for i in range(num_layers - 1)
]
)
else:
self.qkv = LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
self.eqkv = LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
self.to_out_linear = LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
# When using Model Parallel, for the symmetry of GEMM, we change num_layers-1 here to num_layers, which has no effect on the results.
self.to_add_out_linear = LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])

self.ffn1 = LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)])
self.ffn2 = LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)])
self.ffn1_context = LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)])
self.ffn2_context = LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)])

def forward(self, hidden_states, encoder_hidden_states, temb):
print("--------------------this is simplified_sd3------------------------")
Expand Down Expand Up @@ -103,37 +142,49 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
epsilon=1e-06,
)

qkv = self.qkv[i](norm_hidden_states)
eqkv = self.eqkv[i](norm_encoder_hidden_states)
if mp_degree > 1:
qkv = self.qkv_mp[i](norm_hidden_states)
eqkv = self.eqkv_mp[i](norm_encoder_hidden_states)

else:
qkv = self.qkv[i](norm_hidden_states)
eqkv = self.eqkv[i](norm_encoder_hidden_states)

q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv)
bs = hidden_states.shape[0]
q = q.reshape([bs, -1, 24, 64])
k = k.reshape([bs, -1, 24, 64])
v = v.reshape([bs, -1, 24, 64])
head_nums = q.shape[2] // self.head_dim
q = q.reshape([bs, -1, head_nums, self.head_dim])
k = k.reshape([bs, -1, head_nums, self.head_dim])
v = v.reshape([bs, -1, head_nums, self.head_dim])

norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False)
norm_hidden_states1 = norm_hidden_states1.reshape([bs, -1, self.dim])
norm_hidden_states1 = norm_hidden_states1.reshape([bs, -1, head_nums * self.head_dim])
attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[seq1, seq2], axis=1)

# attn_output, context_attn_output = paddlemix.triton_ops.triton_split(
# norm_hidden_states1, num_or_sections=[1024, 154], axis=1
# )

attn_output = paddle.nn.functional.linear(
attn_output, self.to_out_linear[i].weight, self.to_out_linear[i].bias
)

if not context_pre_only:
if mp_degree > 1:
attn_output = self.to_out_linear_mp[i](attn_output)
context_attn_output = self.to_add_out_linear_mp[i](context_attn_output)
else:
attn_output = self.to_out_linear[i](attn_output)
context_attn_output = self.to_add_out_linear[i](context_attn_output)

hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
hidden_states, attn_output, gate_msa, scale_mlp, shift_mlp, epsilon=1e-06
)

# ffn1
ffn_output = self.ffn1[i](norm_hidden_states)
ffn_output = F.gelu(ffn_output, approximate=True)
ffn_output = self.ffn2[i](ffn_output)
if mp_degree > 1:
ffn_output = self.ffn1_mp[i](norm_hidden_states)
ffn_output = F.gelu(ffn_output, approximate=True)
ffn_output = self.ffn2_mp[i](ffn_output)
else:
ffn_output = self.ffn1[i](norm_hidden_states)
ffn_output = F.gelu(ffn_output, approximate=True)
ffn_output = self.ffn2[i](ffn_output)

if context_pre_only:
ffn_output = gate_mlp.unsqueeze(1) * ffn_output
Expand All @@ -149,12 +200,17 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
encoder_hidden_states, context_attn_output, c_gate_msa, c_scale_mlp, c_shift_mlp, epsilon=1e-06
)

context_ffn_output = self.ffn1_context[i](norm_encoder_hidden_states)
context_ffn_output = F.gelu(context_ffn_output, approximate=True)
context_ffn_output = self.ffn2_context[i](context_ffn_output)
if mp_degree > 1:
context_ffn_output = self.ffn1_context_mp[i](norm_encoder_hidden_states)
context_ffn_output = F.gelu(context_ffn_output, approximate=True)
context_ffn_output = self.ffn2_context_mp[i](context_ffn_output)
else:
context_ffn_output = self.ffn1_context[i](norm_encoder_hidden_states)
context_ffn_output = F.gelu(context_ffn_output, approximate=True)
context_ffn_output = self.ffn2_context[i](context_ffn_output)

last_context_ffn_output = context_ffn_output
last_context_hidden_states = encoder_hidden_states
last_context_gate_mlp = c_gate_mlp

return hidden_states
return hidden_states
Loading