diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 0a9b5293d4a2..76813a4a3495 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -18,7 +18,7 @@ ## 📌 Introduction -ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference) +ColossalAI-Inference is a module that accelerates the inference of Transformers models, especially LLMs and DiT diffusion models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
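The DiT support added above is built on Distrifusion-style patched parallelism: each rank holds one patch of the latent, and the attention layers exchange KV slices across ranks, synchronously during a few warm-up diffusion steps and asynchronously (reusing slightly stale KV from the previous step) afterwards. The toy script below is only a rough, self-contained illustration of that communication pattern as implemented by the `DistriSelfAttention` and `DistrifusionFusedAttention` layers added later in this patch; the file name, tensor shapes, step counts, and the gloo backend are illustrative assumptions, not part of the patch.

```python
# toy_stale_kv.py -- hypothetical demo, NOT part of this patch.
# Illustrates the warm-up / stale-KV all-gather pattern used by the
# Distrifusion attention layers added below.
# Run with: torchrun --nproc_per_node 2 toy_stale_kv.py
import torch
import torch.distributed as dist


def stale_kv_step(kv_local, buffer_list, handle, step, warm_steps):
    """Return (full_kv, new_handle) for one diffusion step."""
    if handle is not None:
        handle.wait()  # finish the async gather launched at the previous step
    if step <= warm_steps:
        # warm-up: synchronous all-gather, every rank sees fresh KV
        dist.all_gather(buffer_list, kv_local)
        full_kv = torch.cat(buffer_list, dim=1)
        new_handle = None
    else:
        # steady state: own slice is fresh, remote slices are one step stale
        buffer_list[dist.get_rank()].copy_(kv_local)
        full_kv = torch.cat(buffer_list, dim=1)
        # refresh the buffers in the background for the next step
        new_handle = dist.all_gather(buffer_list, kv_local, async_op=True)
    return full_kv, new_handle


if __name__ == "__main__":
    dist.init_process_group("gloo")
    rank, world = dist.get_rank(), dist.get_world_size()
    seq_per_rank, dim = 4, 8
    buffers = [torch.zeros(1, seq_per_rank, dim) for _ in range(world)]
    handle = None
    for step in range(6):
        # encode the step and rank into the values so staleness is visible
        kv_local = torch.full((1, seq_per_rank, dim), float(step * 10 + rank))
        full_kv, handle = stale_kv_step(kv_local, buffers, handle, step, warm_steps=2)
        if rank == 0:
            print(f"step {step}: KV markers seen = {sorted(full_kv[0, :, 0].unique().tolist())}")
    if handle is not None:
        handle.wait()
    dist.destroy_process_group()
```

In the real layers the buffers hold the concatenated K/V projections and the warm-up length is the `warm_step` attribute; deferring the all-gather lets the communication of one diffusion step overlap with the computation of the next, which is where the speed-up comes from.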

@@ -310,4 +310,14 @@ If you wish to cite relevant research papars, you can find the reference below. journal={arXiv}, year={2023} } + +# Distrifusion +@InProceedings{Li_2024_CVPR, + author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song}, + title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month={June}, + year={2024}, + pages={7183-7193} +} ``` diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 1beb86874826..072ddbcfd298 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM): enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation. start_token_size(int): The size of the start tokens, when using StreamingLLM. generated_token_size(int): The size of the generated tokens, When using StreamingLLM. + patched_parallelism_size(int): The patched parallelism size, when using Distrifusion. """ # NOTE: arrange configs according to their importance and frequency of usage @@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM): start_token_size: int = 4 generated_token_size: int = 512 + # Acceleration for diffusion models (PipeFusion or Distrifusion) + patched_parallelism_size: int = 1 # for distrifusion + # pipeFusion_m_size: int = 1 # for pipefusion + # pipeFusion_n_size: int = 1 # for pipefusion + def __post_init__(self): self.max_context_len_to_capture = self.max_input_len + self.max_output_len self._verify_config() @@ -288,6 +294,14 @@ def _verify_config(self) -> None: # Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
self.start_token_size = self.block_size + # check Distrifusion + # TODO(@lry89757) needs a more detailed check + if self.patched_parallelism_size > 1: + # self.use_patched_parallelism = True + self.tp_size = ( + self.patched_parallelism_size + ) # not real tensor parallelism; tp_size is set to patched_parallelism_size only to satisfy existing config checks + # check prompt template if self.prompt_template is None: return @@ -324,6 +338,7 @@ def to_model_shard_inference_config(self) -> "ModelShardInferenceConfig": use_cuda_kernel=self.use_cuda_kernel, use_spec_dec=self.use_spec_dec, use_flash_attn=use_flash_attn, + patched_parallelism_size=self.patched_parallelism_size, ) return model_inference_config @@ -396,6 +411,7 @@ class ModelShardInferenceConfig: use_cuda_kernel: bool = False use_spec_dec: bool = False use_flash_attn: bool = False + patched_parallelism_size: int = 1 # for diffusion models (Distrifusion) @dataclass diff --git a/colossalai/inference/core/diffusion_engine.py b/colossalai/inference/core/diffusion_engine.py index 75b9889bf28d..8bed508cba55 100644 --- a/colossalai/inference/core/diffusion_engine.py +++ b/colossalai/inference/core/diffusion_engine.py @@ -11,7 +11,7 @@ from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig -from colossalai.inference.modeling.models.diffusion import DiffusionPipe +from colossalai.inference.modeling.layers.diffusion import DiffusionPipe from colossalai.inference.modeling.policy import model_policy_map from colossalai.inference.struct import DiffusionSequence from colossalai.inference.utils import get_model_size, get_model_type diff --git a/colossalai/inference/modeling/models/diffusion.py b/colossalai/inference/modeling/layers/diffusion.py similarity index 100% rename from colossalai/inference/modeling/models/diffusion.py rename to colossalai/inference/modeling/layers/diffusion.py diff --git a/colossalai/inference/modeling/layers/distrifusion.py b/colossalai/inference/modeling/layers/distrifusion.py new file mode 100644 index 000000000000..ea97cceefac9 --- /dev/null +++ b/colossalai/inference/modeling/layers/distrifusion.py @@ -0,0 +1,626 @@ +# Code referenced and adapted from: +# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers +# https://github.com/PipeFusion/PipeFusion + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from diffusers.models import attention_processor +from diffusers.models.attention import Attention +from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed +from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel +from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel +from torch import nn +from torch.distributed import ProcessGroup + +from colossalai.inference.config import ModelShardInferenceConfig +from colossalai.logging import get_dist_logger +from colossalai.shardformer.layer.parallel_module import ParallelModule +from colossalai.utils import get_current_device + +try: + from flash_attn import flash_attn_func + + HAS_FLASH_ATTN = True +except ImportError: + HAS_FLASH_ATTN = False + + +logger = get_dist_logger(__name__) + + +# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py +def 
PixArtAlphaTransformer2DModel_forward( + self: PixArtTransformer2DModel, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, +): + assert hasattr( + self, "patched_parallel_size" + ), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`" + + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch_size = hidden_states.shape[0] + height, width = ( + hidden_states.shape[-2] // self.config.patch_size, + hidden_states.shape[-1] // self.config.patch_size, + ) + hidden_states = self.pos_embed(hidden_states) + + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. 
Output + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk( + 2, dim=1 + ) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + hidden_states = hidden_states.reshape( + shape=( + -1, + height // self.patched_parallel_size, + width, + self.config.patch_size, + self.config.patch_size, + self.out_channels, + ) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=( + -1, + self.out_channels, + height // self.patched_parallel_size * self.config.patch_size, + width * self.config.patch_size, + ) + ) + + # enable Distrifusion Optimization + if hasattr(self, "patched_parallel_size"): + from torch import distributed as dist + + if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape): + self.output_buffer = torch.empty_like(output) + if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape): + self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)] + output = output.contiguous() + dist.all_gather(self.buffer_list, output, async_op=False) + torch.cat(self.buffer_list, dim=2, out=self.output_buffer) + output = self.output_buffer + + return (output,) + + +# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py +def SD3Transformer2DModel_forward( + self: SD3Transformer2DModel, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +) -> Union[torch.FloatTensor]: + + assert hasattr( + self, "patched_parallel_size" + ), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`" + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. 
+ temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.config.patch_size + height = height // patch_size // self.patched_parallel_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + # enable Distrifusion Optimization + if hasattr(self, "patched_parallel_size"): + from torch import distributed as dist + + if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape): + self.output_buffer = torch.empty_like(output) + if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape): + self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)] + output = output.contiguous() + dist.all_gather(self.buffer_list, output, async_op=False) + torch.cat(self.buffer_list, dim=2, out=self.output_buffer) + output = self.output_buffer + + return (output,) + + +# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py +class DistrifusionPatchEmbed(ParallelModule): + def __init__( + self, + module: PatchEmbed, + process_group: Union[ProcessGroup, List[ProcessGroup]], + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + super().__init__() + self.module = module + self.rank = dist.get_rank(group=process_group) + self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size + + @staticmethod + def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs): + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + distrifusion_embed = DistrifusionPatchEmbed( + module, process_group, model_shard_infer_config=model_shard_infer_config + ) + return distrifusion_embed + + def forward(self, latent): + module = self.module + if module.pos_embed_max_size is not None: + height, width = latent.shape[-2:] + else: + height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size + + latent = module.proj(latent) + if module.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if module.layer_norm: + latent = module.norm(latent) + if module.pos_embed is None: + return latent.to(latent.dtype) + # Interpolate or crop positional embeddings as needed + if module.pos_embed_max_size: + pos_embed = module.cropped_pos_embed(height, width) + else: + if module.height != height or module.width != width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=module.pos_embed.shape[-1], + grid_size=(height, width), + base_size=module.base_size, + interpolation_scale=module.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) + else: + pos_embed = module.pos_embed + + b, c, h = pos_embed.shape + pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank] + + 
return (latent + pos_embed).to(latent.dtype) + + +# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py +class DistrifusionConv2D(ParallelModule): + + def __init__( + self, + module: nn.Conv2d, + process_group: Union[ProcessGroup, List[ProcessGroup]], + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + super().__init__() + self.module = module + self.rank = dist.get_rank(group=process_group) + self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size + + @staticmethod + def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs): + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config) + return distrifusion_conv + + def sliced_forward(self, x: torch.Tensor) -> torch.Tensor: + + b, c, h, w = x.shape + + stride = self.module.stride[0] + padding = self.module.padding[0] + + output_h = x.shape[2] // stride // self.patched_parallelism_size + idx = dist.get_rank() + h_begin = output_h * idx * stride - padding + h_end = output_h * (idx + 1) * stride + padding + final_padding = [padding, padding, 0, 0] + if h_begin < 0: + h_begin = 0 + final_padding[2] = padding + if h_end > h: + h_end = h + final_padding[3] = padding + sliced_input = x[:, :, h_begin:h_end, :] + padded_input = F.pad(sliced_input, final_padding, mode="constant") + return F.conv2d( + padded_input, + self.module.weight, + self.module.bias, + stride=stride, + padding="valid", + ) + + def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + output = self.sliced_forward(input) + return output + + +# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py +class DistrifusionFusedAttention(ParallelModule): + + def __init__( + self, + module: attention_processor.Attention, + process_group: Union[ProcessGroup, List[ProcessGroup]], + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + super().__init__() + self.counter = 0 + self.module = module + self.buffer_list = None + self.kv_buffer_idx = dist.get_rank(group=process_group) + self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size + self.handle = None + self.process_group = process_group + self.warm_step = 5 # for warmup + + @staticmethod + def from_native_module( + module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + return DistrifusionFusedAttention( + module=module, + process_group=process_group, + model_shard_infer_config=model_shard_infer_config, + ) + + def _forward( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = 
encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + kv = torch.cat([key, value], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2) + + if self.patched_parallelism_size == 1: + full_kv = kv + else: + if self.buffer_list is None: # buffer not created + full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1) + elif self.counter <= self.warm_step: + # logger.info(f"warmup: {self.counter}") + dist.all_gather( + self.buffer_list, + kv, + group=self.process_group, + async_op=False, + ) + full_kv = torch.cat(self.buffer_list, dim=1) + else: + # logger.info(f"use old kv to infer: {self.counter}") + self.buffer_list[self.kv_buffer_idx].copy_(kv) + full_kv = torch.cat(self.buffer_list, dim=1) + assert self.handle is None, "we should maintain the kv of last step" + self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True) + + key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1) + + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # attention + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) # NOTE(@lry89757) for torch >= 2.2, flash attn has been already integrated into scaled_dot_product_attention, https://pytorch.org/blog/pytorch2-2/ + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. 
+ hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states, encoder_hidden_states + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + + if self.handle is not None: + self.handle.wait() + self.handle = None + + b, l, c = hidden_states.shape + kv_shape = (b, l, self.module.to_k.out_features * 2) + if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape): + + self.buffer_list = [ + torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device()) + for _ in range(self.patched_parallelism_size) + ] + + self.counter = 0 + + attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + output = self._forward( + self.module, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + self.counter += 1 + + return output + + +# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py +class DistriSelfAttention(ParallelModule): + def __init__( + self, + module: Attention, + process_group: Union[ProcessGroup, List[ProcessGroup]], + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + super().__init__() + self.counter = 0 + self.module = module + self.buffer_list = None + self.kv_buffer_idx = dist.get_rank(group=process_group) + self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size + self.handle = None + self.process_group = process_group + self.warm_step = 3 # for warmup + + @staticmethod + def from_native_module( + module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + model_shard_infer_config = kwargs.get("model_shard_infer_config", None) + return DistriSelfAttention( + module=module, + process_group=process_group, + model_shard_infer_config=model_shard_infer_config, + ) + + def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0): + attn = self.module + assert isinstance(attn, Attention) + + residual = hidden_states + + batch_size, sequence_length, _ = hidden_states.shape + + query = attn.to_q(hidden_states) + + encoder_hidden_states = hidden_states + k = self.module.to_k(encoder_hidden_states) + v = self.module.to_v(encoder_hidden_states) + kv = torch.cat([k, v], dim=-1) # shape of 
kv now: (bs, seq_len // parallel_size, dim * 2) + + if self.patched_parallelism_size == 1: + full_kv = kv + else: + if self.buffer_list is None: # buffer not created + full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1) + elif self.counter <= self.warm_step: + # logger.info(f"warmup: {self.counter}") + dist.all_gather( + self.buffer_list, + kv, + group=self.process_group, + async_op=False, + ) + full_kv = torch.cat(self.buffer_list, dim=1) + else: + # logger.info(f"use old kv to infer: {self.counter}") + self.buffer_list[self.kv_buffer_idx].copy_(kv) + full_kv = torch.cat(self.buffer_list, dim=1) + assert self.handle is None, "we should maintain the kv of last step" + self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True) + + if HAS_FLASH_ATTN: + # flash attn + key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1) + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False) + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype) + else: + # naive attn + key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + *args, + **kwargs, + ) -> torch.FloatTensor: + + # async preallocates memo buffer + if self.handle is not None: + self.handle.wait() + self.handle = None + + b, l, c = hidden_states.shape + kv_shape = (b, l, self.module.to_k.out_features * 2) + if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape): + + self.buffer_list = [ + torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device()) + for _ in range(self.patched_parallelism_size) + ] + + self.counter = 0 + + output = self._forward(hidden_states, scale=scale) + + self.counter += 1 + return output diff --git a/colossalai/inference/modeling/models/pixart_alpha.py b/colossalai/inference/modeling/models/pixart_alpha.py index d5774946e365..cc2bee5efd4d 100644 --- a/colossalai/inference/modeling/models/pixart_alpha.py +++ b/colossalai/inference/modeling/models/pixart_alpha.py @@ -14,7 +14,7 @@ from colossalai.logging import get_dist_logger -from .diffusion import DiffusionPipe +from ..layers.diffusion import 
DiffusionPipe logger = get_dist_logger(__name__) diff --git a/colossalai/inference/modeling/models/stablediffusion3.py b/colossalai/inference/modeling/models/stablediffusion3.py index d1c63a6dc665..b123164039c8 100644 --- a/colossalai/inference/modeling/models/stablediffusion3.py +++ b/colossalai/inference/modeling/models/stablediffusion3.py @@ -4,7 +4,7 @@ import torch from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps -from .diffusion import DiffusionPipe +from ..layers.diffusion import DiffusionPipe # TODO(@lry89757) temporarily image, please support more return output diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py index 356056ba73e7..1150b2432cc5 100644 --- a/colossalai/inference/modeling/policy/pixart_alpha.py +++ b/colossalai/inference/modeling/policy/pixart_alpha.py @@ -1,9 +1,17 @@ +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel from torch import nn from colossalai.inference.config import RPC_PARAM -from colossalai.inference.modeling.models.diffusion import DiffusionPipe +from colossalai.inference.modeling.layers.diffusion import DiffusionPipe +from colossalai.inference.modeling.layers.distrifusion import ( + DistrifusionConv2D, + DistrifusionPatchEmbed, + DistriSelfAttention, + PixArtAlphaTransformer2DModel_forward, +) from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward -from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class PixArtAlphaInferPolicy(Policy, RPC_PARAM): @@ -12,9 +20,46 @@ def __init__(self) -> None: def module_policy(self): policy = {} + + if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1: + + policy[PixArtTransformer2DModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="pos_embed.proj", + target_module=DistrifusionConv2D, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ), + SubModuleReplacementDescription( + suffix="pos_embed", + target_module=DistrifusionPatchEmbed, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ), + ], + attribute_replacement={ + "patched_parallel_size": self.shard_config.extra_kwargs[ + "model_shard_infer_config" + ].patched_parallelism_size + }, + method_replacement={"forward": PixArtAlphaTransformer2DModel_forward}, + ) + + policy[BasicTransformerBlock] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn1", + target_module=DistriSelfAttention, + kwargs={ + "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], + }, + ) + ] + ) + self.append_or_create_method_replacement( description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe ) + return policy def preprocess(self) -> nn.Module: diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py index c9877f7dcae6..39b764b92887 100644 --- a/colossalai/inference/modeling/policy/stablediffusion3.py +++ b/colossalai/inference/modeling/policy/stablediffusion3.py @@ -1,9 +1,17 @@ +from diffusers.models.attention import JointTransformerBlock +from 
diffusers.models.transformers import SD3Transformer2DModel from torch import nn from colossalai.inference.config import RPC_PARAM -from colossalai.inference.modeling.models.diffusion import DiffusionPipe +from colossalai.inference.modeling.layers.diffusion import DiffusionPipe +from colossalai.inference.modeling.layers.distrifusion import ( + DistrifusionConv2D, + DistrifusionFusedAttention, + DistrifusionPatchEmbed, + SD3Transformer2DModel_forward, +) from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward -from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class StableDiffusion3InferPolicy(Policy, RPC_PARAM): @@ -12,6 +20,42 @@ def __init__(self) -> None: def module_policy(self): policy = {} + + if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1: + + policy[SD3Transformer2DModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="pos_embed.proj", + target_module=DistrifusionConv2D, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ), + SubModuleReplacementDescription( + suffix="pos_embed", + target_module=DistrifusionPatchEmbed, + kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]}, + ), + ], + attribute_replacement={ + "patched_parallel_size": self.shard_config.extra_kwargs[ + "model_shard_infer_config" + ].patched_parallelism_size + }, + method_replacement={"forward": SD3Transformer2DModel_forward}, + ) + + policy[JointTransformerBlock] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn", + target_module=DistrifusionFusedAttention, + kwargs={ + "model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"], + }, + ) + ] + ) + self.append_or_create_method_replacement( description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe ) diff --git a/examples/inference/stable_diffusion/README.md b/examples/inference/stable_diffusion/README.md new file mode 100644 index 000000000000..c11b9804392c --- /dev/null +++ b/examples/inference/stable_diffusion/README.md @@ -0,0 +1,22 @@ +## File Structure +``` +|- sd3_generation.py: an example of how to use the ColossalAI Inference Engine to load a diffusion model and generate results. +|- compute_metric.py: compare the quality of images generated with and without acceleration methods such as Distrifusion +|- benchmark_sd3.py: benchmark the performance of our InferenceEngine +|- run_benchmark.sh: run the benchmark commands +``` +Note: `compute_metric.py` requires extra dependencies; install them with `pip install -r requirements.txt` (the `requirements.txt` is located in `examples/inference/stable_diffusion/`). + +## Run Inference + +The provided example `sd3_generation.py` shows how to configure the inference engine, initialize it, and run inference on a given model. We've added `DiffusionPipeline` as the model class, and the script works out of the box with Stable Diffusion 3. 
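Programmatically, the flow inside `sd3_generation.py` boils down to roughly the sketch below; it is a minimal sketch based on the APIs used elsewhere in this patch, where the model id, prompt, and output path are placeholders and the exact launch helper may differ across ColossalAI versions. The CLI commands that follow drive the actual script.

```python
# Minimal sketch (not the shipped script): configure the engine for Distrifusion
# and generate an image. Launch with: colossalai run --nproc_per_node <N> sketch.py
import colossalai
from torch import distributed as dist

from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

colossalai.launch_from_torch()  # assumption: recent ColossalAI; older releases also take a config dict

inference_config = InferenceConfig(
    dtype="fp16",
    patched_parallelism_size=dist.get_world_size(),  # split the image patches across all launched GPUs
)
engine = InferenceEngine(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # placeholder: supported pipeline id or local path
    inference_config=inference_config,
    verbose=True,
)

out = engine.generate(
    prompts=["a photo of a cat"],  # placeholder prompt
    generation_config=DiffusionGenerationConfig(num_inference_steps=50, height=1024, width=1024),
)[0]
if dist.get_rank() == 0:  # only one rank needs to write the image
    out.save("output.jpg")
```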
+ +For a basic setting, you could run the example by: +```bash +colossalai run --nproc_per_node 1 sd3_generation.py -m PATH_MODEL -p "hello world" +``` + +Run multi-GPU inference (Patched Parallelism), as in the following example using 2 GPUs: +```bash +colossalai run --nproc_per_node 2 sd3_generation.py -m PATH_MODEL +``` diff --git a/examples/inference/stable_diffusion/benchmark_sd3.py b/examples/inference/stable_diffusion/benchmark_sd3.py new file mode 100644 index 000000000000..19db57c33c82 --- /dev/null +++ b/examples/inference/stable_diffusion/benchmark_sd3.py @@ -0,0 +1,179 @@ +import argparse +import json +import time +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from diffusers import DiffusionPipeline + +import colossalai +from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +GIGABYTE = 1024**3 +MEGABYTE = 1024 * 1024 + +_DTYPE_MAPPING = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, +} + + +def log_generation_time(log_data, log_file): + with open(log_file, "a") as f: + json.dump(log_data, f, indent=2) + f.write("\n") + + +def warmup(engine, args): + for _ in range(args.n_warm_up_steps): + engine.generate( + prompts=["hello world"], + generation_config=DiffusionGenerationConfig( + num_inference_steps=args.num_inference_steps, height=args.height[0], width=args.width[0] + ), + ) + + +def profile_context(args): + return ( + torch.profiler.profile( + record_shapes=True, + with_stack=True, + with_modules=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + ) + if args.profile + else nullcontext() + ) + + +def log_and_profile(h, w, avg_time, log_msg, args, model_name, mode, prof=None): + log_data = { + "mode": mode, + "model": model_name, + "batch_size": args.batch_size, + "patched_parallel_size": args.patched_parallel_size, + "num_inference_steps": args.num_inference_steps, + "height": h, + "width": w, + "dtype": args.dtype, + "profile": args.profile, + "n_warm_up_steps": args.n_warm_up_steps, + "n_repeat_times": args.n_repeat_times, + "avg_generation_time": avg_time, + "log_message": log_msg, + } + + if args.log: + log_file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}.json" + log_generation_time(log_data=log_data, log_file=log_file) + + if args.profile: + file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}_prof.json" + prof.export_chrome_trace(file) + + +def benchmark_colossalai(rank, world_size, port, args): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + from colossalai.cluster.dist_coordinator import DistCoordinator + + coordinator = DistCoordinator() + + inference_config = InferenceConfig( + dtype=args.dtype, + patched_parallelism_size=args.patched_parallel_size, + ) + engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False) + + warmup(engine, args) + + for h, w in zip(args.height, args.width): + with profile_context(args) as prof: + start = time.perf_counter() + for _ in range(args.n_repeat_times): + engine.generate( + prompts=["hello world"], + generation_config=DiffusionGenerationConfig( + num_inference_steps=args.num_inference_steps, height=h, width=w + ), + ) + end = time.perf_counter() + + avg_time = (end - start) / args.n_repeat_times + 
log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {avg_time:.2f}s" + coordinator.print_on_master(log_msg) + + if dist.get_rank() == 0: + log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "colossalai", prof=prof) + + +def benchmark_diffusers(args): + model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda") + + for _ in range(args.n_warm_up_steps): + model( + prompt="hello world", + num_inference_steps=args.num_inference_steps, + height=args.height[0], + width=args.width[0], + ) + + for h, w in zip(args.height, args.width): + with profile_context(args) as prof: + start = time.perf_counter() + for _ in range(args.n_repeat_times): + model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w) + end = time.perf_counter() + + avg_time = (end - start) / args.n_repeat_times + log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {avg_time:.2f}s" + print(log_msg) + + log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "diffusers", prof) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def benchmark(args): + if args.mode == "colossalai": + spawn(benchmark_colossalai, nprocs=args.patched_parallel_size, args=args) + elif args.mode == "diffusers": + benchmark_diffusers(args) + + +""" +# enable log +python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log +python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log + +# enable profiler +python examples/inference/stable_diffusion/benchmark_sd3.py -m "stabilityai/stable-diffusion-3-medium-diffusers" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 +python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 +python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20 +""" + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size") + parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps") + parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024, 2048], help="Height list") + parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024, 2048], help="Width list") + parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") + parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm up steps") + parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times") + parser.add_argument("--profile", default=False, action="store_true", help="Enable torch profiler") + parser.add_argument("--log", default=False, action="store_true", help="Enable logging") + parser.add_argument("-m", "--model", default="stabilityai/stable-diffusion-3-medium-diffusers", help="Model path") + parser.add_argument( + "--mode", default="colossalai", choices=["colossalai", "diffusers"], help="Inference framework mode" + ) 
+ args = parser.parse_args() + benchmark(args) diff --git a/examples/inference/stable_diffusion/compute_metric.py b/examples/inference/stable_diffusion/compute_metric.py new file mode 100644 index 000000000000..14c92501b66d --- /dev/null +++ b/examples/inference/stable_diffusion/compute_metric.py @@ -0,0 +1,80 @@ +# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py +import argparse +import os + +import numpy as np +import torch +from cleanfid import fid +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio +from torchvision.transforms import Resize +from tqdm import tqdm + + +def read_image(path: str): + """ + input: path + output: tensor (C, H, W) + """ + img = np.asarray(Image.open(path)) + if len(img.shape) == 2: + img = np.repeat(img[:, :, None], 3, axis=2) + img = torch.from_numpy(img).permute(2, 0, 1) + return img + + +class MultiImageDataset(Dataset): + def __init__(self, root0, root1, is_gt=False): + super().__init__() + self.root0 = root0 + self.root1 = root1 + file_names0 = os.listdir(root0) + file_names1 = os.listdir(root1) + + self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")]) + self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")]) + self.is_gt = is_gt + assert len(self.image_names0) == len(self.image_names1) + + def __len__(self): + return len(self.image_names0) + + def __getitem__(self, idx): + img0 = read_image(os.path.join(self.root0, self.image_names0[idx])) + if self.is_gt: + # resize to 1024 x 1024 + img0 = Resize((1024, 1024))(img0) + img1 = read_image(os.path.join(self.root1, self.image_names1[idx])) + + batch_list = [img0, img1] + return batch_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--is_gt", action="store_true") + parser.add_argument("--input_root0", type=str, required=True) + parser.add_argument("--input_root1", type=str, required=True) + args = parser.parse_args() + + psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda") + lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda") + + dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt) + dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) + + progress_bar = tqdm(dataloader) + with torch.inference_mode(): + for i, batch in enumerate(progress_bar): + batch = [img.to("cuda") / 255 for img in batch] + batch_size = batch[0].shape[0] + psnr.update(batch[0], batch[1]) + lpips.update(batch[0], batch[1]) + fid_score = fid.compute_fid(args.input_root0, args.input_root1) + + print("PSNR:", psnr.compute().item()) + print("LPIPS:", lpips.compute().item()) + print("FID:", fid_score) diff --git a/examples/inference/stable_diffusion/requirements.txt b/examples/inference/stable_diffusion/requirements.txt new file mode 100644 index 000000000000..c4e74162dfb5 --- /dev/null +++ b/examples/inference/stable_diffusion/requirements.txt @@ -0,0 +1,3 @@ +torchvision +torchmetrics +cleanfid diff --git a/examples/inference/stable_diffusion/run_benchmark.sh b/examples/inference/stable_diffusion/run_benchmark.sh new file mode 100644 index 000000000000..f3e45a335219 --- /dev/null +++ 
b/examples/inference/stable_diffusion/run_benchmark.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers") +parallelism=(1 2 4 8) +resolutions=(1024 2048 3840) +modes=("colossalai" "diffusers") + +CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +for model in "${models[@]}"; do + for p in "${parallelism[@]}"; do + for resolution in "${resolutions[@]}"; do + for mode in "${modes[@]}"; do + if [[ "$mode" == "colossalai" && "$p" == 1 ]]; then + continue + fi + if [[ "$mode" == "diffusers" && "$p" != 1 ]]; then + continue + fi + CUDA_VISIBLE_DEVICES_set_n_least_memory_usage $p + + cmd="python examples/inference/stable_diffusion/benchmark_sd3.py -m \"$model\" -p $p --mode $mode --log -H $resolution -w $resolution" + + echo "Executing: $cmd" + eval $cmd + done + done + done +done diff --git a/examples/inference/stable_diffusion/sd3_generation.py b/examples/inference/stable_diffusion/sd3_generation.py index fe989eed7c2d..9e146c34b937 100644 --- a/examples/inference/stable_diffusion/sd3_generation.py +++ b/examples/inference/stable_diffusion/sd3_generation.py @@ -1,18 +1,17 @@ import argparse -from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline -from torch import bfloat16, float16, float32 +from diffusers import DiffusionPipeline +from torch import bfloat16 +from torch import distributed as dist +from torch import float16, float32 import colossalai from colossalai.cluster import DistCoordinator from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig from colossalai.inference.core.engine import InferenceEngine -from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy -from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy # For Stable Diffusion 3, we'll use the following configuration -MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0] -POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0] +MODEL_CLS = DiffusionPipeline TORCH_DTYPE_MAP = { "fp16": float16, @@ -43,20 +42,27 @@ def infer(args): max_batch_size=args.max_batch_size, tp_size=args.tp_size, use_cuda_kernel=args.use_cuda_kernel, + patched_parallelism_size=dist.get_world_size(), ) - engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True) + engine = InferenceEngine(model, inference_config=inference_config, verbose=True) # ============================== # Generation # ============================== coordinator.print_on_master(f"Generating...") out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0] - out.save("cat.jpg") + if dist.get_rank() == 0: + out.save(f"cat_parallel_size{dist.get_world_size()}.jpg") coordinator.print_on_master(out) # colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH + # colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 +# colossalai run --nproc_per_node 2 
examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 + +# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1 +# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1 if __name__ == "__main__":