From 20ed40e8b9cc7ebebd8d4c6eb8ffe106cebb8125 Mon Sep 17 00:00:00 2001 From: bvhari Date: Sun, 1 Sep 2024 12:15:04 +0530 Subject: [PATCH 01/10] Add log-cosh loss --- .../mixin/ModelSetupDiffusionLossMixin.py | 28 +++++++++++++++++++ modules/ui/TrainingTab.py | 21 ++++++++------ modules/util/config/TrainConfig.py | 2 ++ 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py b/modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py index 9385733f..2a8cdc3f 100644 --- a/modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py +++ b/modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py @@ -56,6 +56,15 @@ def __align_prop_losses( losses = self.__align_prop_loss_fn(data['predicted']) return losses * config.align_prop_weight + + def __log_cosh_loss( + self, + pred: torch.Tensor, + target: torch.Tensor, + ): + diff = pred - target + loss = diff + torch.nn.functional.softplus(-2.0*diff) - torch.log(torch.full(size=diff.size(), fill_value=2.0, dtype=torch.float32, device=diff.device)) + return loss def __masked_losses( self, @@ -91,6 +100,18 @@ def __masked_losses( normalize_masked_area_loss=config.normalize_masked_area_loss, ).mean([1, 2, 3]) * config.mae_strength + # log-cosh Loss + if config.log_cosh_strength != 0: + losses += masked_losses( + losses=self.__log_cosh_loss( + data['predicted'].to(dtype=torch.float32), + data['target'].to(dtype=torch.float32) + ), + mask=batch['latent_mask'].to(dtype=torch.float32), + unmasked_weight=config.unmasked_weight, + normalize_masked_area_loss=config.normalize_masked_area_loss, + ).mean([1, 2, 3]) * config.log_cosh_strength + # VB loss if config.vb_loss_strength != 0 and 'predicted_var_values' in data and self.__coefficients is not None: losses += masked_losses( @@ -133,6 +154,13 @@ def __unmasked_losses( reduction='none' ).mean([1, 2, 3]) * config.mae_strength + # log-cosh Loss + if config.log_cosh_strength != 0: + losses += self.__log_cosh_loss( + data['predicted'].to(dtype=torch.float32), + data['target'].to(dtype=torch.float32) + ).mean([1, 2, 3]) * config.log_cosh_strength + # VB loss if config.vb_loss_strength != 0 and 'predicted_var_values' in data: losses += vb_losses( diff --git a/modules/ui/TrainingTab.py b/modules/ui/TrainingTab.py index daded2b0..f8293a0d 100644 --- a/modules/ui/TrainingTab.py +++ b/modules/ui/TrainingTab.py @@ -652,26 +652,31 @@ def __create_loss_frame(self, master, row, supports_vb_loss: bool = False): tooltip="Mean Absolute Error strength for custom loss settings. MAE + MSE Strengths generally should sum to 1.") components.entry(frame, 1, 1, self.ui_state, "mae_strength") + # log-cosh Strength + components.label(frame, 2, 0, "log-cosh Strength", + tooltip="Log - Hyperbolic cosine Error strength for custom loss settings. Should be used indepedently.") + components.entry(frame, 2, 1, self.ui_state, "log_cosh_strength") + if supports_vb_loss: # VB Strength - components.label(frame, 2, 0, "VB Strength", + components.label(frame, 3, 0, "VB Strength", tooltip="Variational lower-bound strength for custom loss settings. Should be set to 1 for variational diffusion models") - components.entry(frame, 2, 1, self.ui_state, "vb_loss_strength") + components.entry(frame, 3, 1, self.ui_state, "vb_loss_strength") # Loss Weight function - components.label(frame, 3, 0, "Loss Weight Function", + components.label(frame, 4, 0, "Loss Weight Function", tooltip="Choice of loss weight function. 
Can help the model learn details more accurately.") - components.options(frame, 3, 1, [str(x) for x in list(LossWeight)], self.ui_state, "loss_weight_fn") + components.options(frame, 4, 1, [str(x) for x in list(LossWeight)], self.ui_state, "loss_weight_fn") # Loss weight strength - components.label(frame, 4, 0, "Gamma", + components.label(frame, 5, 0, "Gamma", tooltip="Inverse strength of loss weighting. Range: 1-20, only applies to Min SNR and P2.") - components.entry(frame, 4, 1, self.ui_state, "loss_weight_strength") + components.entry(frame, 5, 1, self.ui_state, "loss_weight_strength") # Loss Scaler - components.label(frame, 5, 0, "Loss Scaler", + components.label(frame, 6, 0, "Loss Scaler", tooltip="Selects the type of loss scaling to use during training. Functionally equated as: Loss * selection") - components.options(frame, 5, 1, [str(x) for x in list(LossScaler)], self.ui_state, "loss_scaler") + components.options(frame, 6, 1, [str(x) for x in list(LossScaler)], self.ui_state, "loss_scaler") def __open_optimizer_params_window(self): window = OptimizerParamsWindow(self.master, self.train_config, self.ui_state) diff --git a/modules/util/config/TrainConfig.py b/modules/util/config/TrainConfig.py index d844500d..702726b5 100644 --- a/modules/util/config/TrainConfig.py +++ b/modules/util/config/TrainConfig.py @@ -283,6 +283,7 @@ class TrainConfig(BaseConfig): align_prop_cfg_scale: float mse_strength: float mae_strength: float + log_cosh_strength: float vb_loss_strength: float loss_weight_fn: LossWeight loss_weight_strength: float @@ -685,6 +686,7 @@ def default_values() -> 'TrainConfig': data.append(("align_prop_cfg_scale", 7.0, float, False)) data.append(("mse_strength", 1.0, float, False)) data.append(("mae_strength", 0.0, float, False)) + data.append(("log_cosh_strength", 0.0, float, False)) data.append(("vb_loss_strength", 1.0, float, False)) data.append(("loss_weight_fn", LossWeight.CONSTANT, LossWeight, False)) data.append(("loss_weight_strength", 5.0, float, False)) From 4908fa119c4611a001e617da7da99425acbc5dfd Mon Sep 17 00:00:00 2001 From: Nerogar Date: Wed, 4 Sep 2024 17:36:02 +0200 Subject: [PATCH 02/10] fixes for the flux LoRA preset --- training_presets/#flux LoRA.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/training_presets/#flux LoRA.json b/training_presets/#flux LoRA.json index 277039d2..f6e86423 100644 --- a/training_presets/#flux LoRA.json +++ b/training_presets/#flux LoRA.json @@ -8,15 +8,15 @@ "output_model_format": "SAFETENSORS", "resolution": "768", "prior": { - "train": false, + "train": true, "weight_dtype": "NFLOAT_4" }, "text_encoder": { - "train": false, - "weight_dtype": "NFLOAT_4" + "train": false }, "text_encoder_2": { - "train": false + "train": false, + "weight_dtype": "NFLOAT_4" }, "training_method": "LORA", "vae": { From 80b93b0a8f443e2d23f1fbd48e63d23759934624 Mon Sep 17 00:00:00 2001 From: Nerogar Date: Wed, 4 Sep 2024 20:06:59 +0200 Subject: [PATCH 03/10] update diffusers --- requirements-global.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-global.txt b/requirements-global.txt index 895d96b7..e176ee53 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -15,7 +15,7 @@ tensorboard==2.17.0 pytorch-lightning==2.2.5 # stable diffusion --e git+https://github.com/huggingface/diffusers.git@5ffbe14#egg=diffusers +-e git+https://github.com/huggingface/diffusers.git@2ee3215#egg=diffusers transformers==4.42.3 omegaconf==2.3.0 # needed to load stable diffusion from single 
ckpt files invisible-watermark==0.2.0 # needed for the SDXL pipeline From 1e1020e0d9461528a307c88e0395ce3097508dbc Mon Sep 17 00:00:00 2001 From: Nerogar Date: Thu, 5 Sep 2024 22:19:31 +0200 Subject: [PATCH 04/10] Revert "lazy initialization for the dora_scale parameter" This reverts commit 50cf940d --- modules/module/LoRAModule.py | 42 +++++++++++------------------------- 1 file changed, 13 insertions(+), 29 deletions(-) diff --git a/modules/module/LoRAModule.py b/modules/module/LoRAModule.py index dcc722aa..daf4f89d 100644 --- a/modules/module/LoRAModule.py +++ b/modules/module/LoRAModule.py @@ -342,54 +342,38 @@ class DoRAModule(LoRAModule): def __init__(self, *args, **kwargs): self.dora_scale = None - self.dora_scale_initialized = False self.norm_epsilon = kwargs.pop('norm_epsilon', False) super().__init__(*args, **kwargs) - self.register_load_state_dict_post_hook(self.__set_dora_scale_initialized) - - def __set_dora_scale_initialized(self, *args, **kwargs): - self.dora_scale_initialized = True def initialize_weights(self): super().initialize_weights() - self.dora_num_dims = len(self.shape) - 1 + # Thanks to KohakuBlueLeaf once again for figuring out the shape + # wrangling that works for both Linear and Convolutional layers. If you + # were just doing this for Linear, it would be substantially simpler. + orig_weight = get_unquantized_weight(self.orig_module, torch.float) + self.dora_num_dims = orig_weight.dim() - 1 self.dora_scale = nn.Parameter( - torch.ones(size=(self.shape[1],)) - .reshape(self.shape[1], *[1] * self.dora_num_dims) + torch.norm( + orig_weight.transpose(1, 0).reshape(orig_weight.shape[1], -1), + dim=1, keepdim=True) + .reshape(orig_weight.shape[1], *[1] * self.dora_num_dims) .transpose(1, 0) .to(device=self.orig_module.weight.device) ) + del orig_weight + def check_initialized(self): super().check_initialized() assert self.dora_scale is not None def forward(self, x, *args, **kwargs): + self.check_initialized() + A = self.lora_down.weight B = self.lora_up.weight orig_weight = get_unquantized_weight(self.orig_module, A.dtype) - - if not self.dora_scale_initialized: - # Thanks to KohakuBlueLeaf once again for figuring out the shape - # wrangling that works for both Linear and Convolutional layers. If you - # were just doing this for Linear, it would be substantially simpler. - - # dora_scale is not initialized in initialize_weights, because at that point, orig_weight - # could still be located on the wrong device, or not loaded at all. 
Instead, we use lazy - # initialization during the first forward pass - self.dora_scale = nn.Parameter( - torch.norm( - orig_weight.transpose(1, 0).reshape(orig_weight.shape[1], -1), - dim=1, keepdim=True) - .reshape(orig_weight.shape[1], *[1] * self.dora_num_dims) - .transpose(1, 0) - .to(device=self.orig_module.weight.device) - ) - self.dora_scale_initialized = True - - self.check_initialized() - WP = orig_weight + (self.make_weight(A, B) * (self.alpha / self.rank)) del orig_weight # A norm should never really end up zero at any point, but epsilon just From e895ddc0344d0eb43ca29825c749c9b7dccd93f6 Mon Sep 17 00:00:00 2001 From: Nerogar Date: Thu, 5 Sep 2024 22:34:14 +0200 Subject: [PATCH 05/10] move the model to the train device before calling setup_model --- modules/ui/SampleWindow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/ui/SampleWindow.py b/modules/ui/SampleWindow.py index a523dc33..fa99da1a 100644 --- a/modules/ui/SampleWindow.py +++ b/modules/ui/SampleWindow.py @@ -123,6 +123,7 @@ def __load_model(self) -> BaseModel: ) model.train_config = self.initial_train_config + model_setup.setup_train_device(model, self.initial_train_config) model_setup.setup_model(model, self.initial_train_config) return model From 86cb13cfea843886a89799d27c48ec96421a8caf Mon Sep 17 00:00:00 2001 From: Nerogar Date: Sat, 7 Sep 2024 00:26:03 +0200 Subject: [PATCH 06/10] add xformers support for Flux --- modules/modelSetup/BaseFluxSetup.py | 51 ++++---- .../flux/FluxXFormersAttnProcessor.py | 112 ++++++++++++++++++ 2 files changed, 139 insertions(+), 24 deletions(-) create mode 100644 modules/modelSetup/flux/FluxXFormersAttnProcessor.py diff --git a/modules/modelSetup/BaseFluxSetup.py b/modules/modelSetup/BaseFluxSetup.py index 1a695062..b1e45f75 100644 --- a/modules/modelSetup/BaseFluxSetup.py +++ b/modules/modelSetup/BaseFluxSetup.py @@ -3,6 +3,7 @@ from modules.model.FluxModel import FluxModel, FluxModelEmbedding from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.modelSetup.flux.FluxXFormersAttnProcessor import FluxXFormersAttnProcessor from modules.modelSetup.mixin.ModelSetupDebugMixin import ModelSetupDebugMixin from modules.modelSetup.mixin.ModelSetupDiffusionLossMixin import ModelSetupDiffusionLossMixin from modules.modelSetup.mixin.ModelSetupEmbeddingMixin import ModelSetupEmbeddingMixin @@ -18,6 +19,7 @@ from modules.util.config.TrainConfig import TrainConfig from modules.util.conv_util import apply_circular_padding_to_conv2d from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context +from modules.util.enum.AttentionMechanism import AttentionMechanism from modules.util.enum.TrainingMethod import TrainingMethod from modules.util.quantization_util import set_nf4_compute_type from modules.util.TrainProgress import TrainProgress @@ -25,6 +27,9 @@ import torch from torch import Tensor +from diffusers.models.attention_processor import FluxAttnProcessor2_0 +from diffusers.utils import is_xformers_available + class BaseFluxSetup( BaseModelSetup, @@ -41,30 +46,28 @@ def _setup_optimizations( model: FluxModel, config: TrainConfig, ): - # if config.attention_mechanism == AttentionMechanism.DEFAULT: - # pass - # # model.transformer.set_attn_processor(AttnProcessor()) - # elif config.attention_mechanism == AttentionMechanism.XFORMERS and is_xformers_available(): - # try: - # # TODO: there is no xformers attention processor like JointAttnProcessor2_0 yet - # # model.transformer.set_attn_processor(XFormersAttnProcessor()) - # 
model.vae.enable_xformers_memory_efficient_attention() - # except Exception as e: - # print( - # "Could not enable memory efficient attention. Make sure xformers is installed" - # f" correctly and a GPU is available: {e}" - # ) - # elif config.attention_mechanism == AttentionMechanism.SDP: - # model.transformer.set_attn_processor(JointAttnProcessor2_0()) - # - # if is_xformers_available(): - # try: - # model.vae.enable_xformers_memory_efficient_attention() - # except Exception as e: - # print( - # "Could not enable memory efficient attention. Make sure xformers is installed" - # f" correctly and a GPU is available: {e}" - # ) + if config.attention_mechanism == AttentionMechanism.DEFAULT: + model.transformer.set_attn_processor(FluxAttnProcessor2_0()) + elif config.attention_mechanism == AttentionMechanism.XFORMERS and is_xformers_available(): + try: + model.transformer.set_attn_processor(FluxXFormersAttnProcessor(model.train_dtype.torch_dtype())) + model.vae.enable_xformers_memory_efficient_attention() + except Exception as e: + print( + "Could not enable memory efficient attention. Make sure xformers is installed" + f" correctly and a GPU is available: {e}" + ) + elif config.attention_mechanism == AttentionMechanism.SDP: + model.transformer.set_attn_processor(FluxAttnProcessor2_0()) + + if is_xformers_available(): + try: + model.vae.enable_xformers_memory_efficient_attention() + except Exception as e: + print( + "Could not enable memory efficient attention. Make sure xformers is installed" + f" correctly and a GPU is available: {e}" + ) if config.gradient_checkpointing.enabled(): enable_checkpointing_for_flux_transformer( diff --git a/modules/modelSetup/flux/FluxXFormersAttnProcessor.py b/modules/modelSetup/flux/FluxXFormersAttnProcessor.py new file mode 100644 index 00000000..71a0282e --- /dev/null +++ b/modules/modelSetup/flux/FluxXFormersAttnProcessor.py @@ -0,0 +1,112 @@ +from typing import Optional + +import torch + +from diffusers.models.attention_processor import Attention, xformers + + +class FluxXFormersAttnProcessor: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self, dtype: torch.dtype): + self.dtype = dtype + + def apply_rotary_emb( + self, + x: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None].transpose(1, 2) + sin = sin[None, None].transpose(1, 2) + cos, sin = cos.to(x.device), sin.to(x.device) + + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. 
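+        # (the `sample` stream is projected to Q/K/V here; if a separate `context` stream is
+        # passed, its projections are computed below and prepended before attention)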
+ query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + if image_rotary_emb is not None: + + query = self.apply_rotary_emb(query, image_rotary_emb) + key = self.apply_rotary_emb(key, image_rotary_emb) + + hidden_states = xformers.ops.memory_efficient_attention( + query.to(dtype=self.dtype), + key.to(dtype=self.dtype), + value.to(dtype=self.dtype), + ) + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states From bc74f99ef3a476cc9283df9c0a33394aa10dfb29 Mon Sep 17 00:00:00 2001 From: Nerogar Date: Sat, 7 Sep 2024 00:44:24 +0200 Subject: [PATCH 07/10] fix a deprecation warning for flux --- modules/model/FluxModel.py | 5 ++--- modules/modelSampler/FluxSampler.py | 3 +-- modules/modelSetup/BaseFluxSetup.py | 3 +-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/modules/model/FluxModel.py b/modules/model/FluxModel.py index 84884582..a1ceadcc 100644 --- a/modules/model/FluxModel.py +++ b/modules/model/FluxModel.py @@ -290,16 +290,15 @@ def encode_text( return text_encoder_2_output, pooled_text_encoder_1_output - def prepare_latent_image_ids(self, batch_size, height, width, device, dtype): + def prepare_latent_image_ids(self, height, width, device, dtype): latent_image_ids = torch.zeros(height // 2, width // 2, 3) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 
2)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) latent_image_ids = latent_image_ids.reshape( - batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + latent_image_id_height * latent_image_id_width, latent_image_id_channels ) return latent_image_ids.to(device=device, dtype=dtype) diff --git a/modules/modelSampler/FluxSampler.py b/modules/modelSampler/FluxSampler.py index 94a8eae7..1c241a60 100644 --- a/modules/modelSampler/FluxSampler.py +++ b/modules/modelSampler/FluxSampler.py @@ -103,7 +103,6 @@ def __sample_base( ) image_ids = self.model.prepare_latent_image_ids( - latent_image.shape[0], height // vae_scale_factor, width // vae_scale_factor, self.train_device, @@ -143,7 +142,7 @@ def __sample_base( if "generator" in set(inspect.signature(noise_scheduler.step).parameters.keys()): extra_step_kwargs["generator"] = generator - text_ids = torch.zeros(latent_image.shape[0], prompt_embedding.shape[1], 3, device=self.train_device) + text_ids = torch.zeros(prompt_embedding.shape[1], 3, device=self.train_device) self.model.transformer_to(self.train_device) for i, timestep in enumerate(tqdm(timesteps, desc="sampling")): diff --git a/modules/modelSetup/BaseFluxSetup.py b/modules/modelSetup/BaseFluxSetup.py index b1e45f75..7bced70c 100644 --- a/modules/modelSetup/BaseFluxSetup.py +++ b/modules/modelSetup/BaseFluxSetup.py @@ -447,12 +447,11 @@ def predict( guidance = None text_ids = torch.zeros( - size=(latent_image.shape[0], text_encoder_output.shape[1], 3), + size=(text_encoder_output.shape[1], 3), device=self.train_device, ) image_ids = model.prepare_latent_image_ids( - latent_image.shape[0], latent_input.shape[2], latent_input.shape[3], self.train_device, From ff5e5baad0c690d2fc9903de145ccabb57f4c068 Mon Sep 17 00:00:00 2001 From: Nerogar Date: Sat, 7 Sep 2024 16:29:59 +0200 Subject: [PATCH 08/10] fix ram offloaded checkpointing during sampling --- modules/util/checkpointing_util.py | 43 ++++++++++++++---------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/modules/util/checkpointing_util.py b/modules/util/checkpointing_util.py index 264a22c0..da35586a 100644 --- a/modules/util/checkpointing_util.py +++ b/modules/util/checkpointing_util.py @@ -1,12 +1,13 @@ import inspect -from typing import Callable, Any +from typing import Any, Callable import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock from diffusers.models.unets.unet_stable_cascade import SDCascadeAttnBlock, SDCascadeResBlock, SDCascadeTimestepBlock -from torch import nn -from torch.utils.checkpoint import checkpoint from transformers.models.clip.modeling_clip import CLIPEncoderLayer from transformers.models.t5.modeling_t5 import T5Block @@ -32,7 +33,7 @@ def __get_args_indices(fun: Callable, arg_names: list[str]) -> list[int]: signature = dict(inspect.signature(fun).parameters) indices = [] - for i, (key, value) in enumerate(signature.items()): + for i, key in enumerate(signature.keys()): if key in arg_names: indices.append(i) @@ -42,33 +43,22 @@ def __get_args_indices(fun: Callable, arg_names: list[str]) -> list[int]: def to_( data: torch.Tensor | list | tuple | dict, device: torch.device, - include_parameter_indices: list[int] = 
None, + include_parameter_indices: list[int] | None = None, ): - if include_parameter_indices is not None and len(include_parameter_indices) == 0: - include_parameter_indices = None + if include_parameter_indices is None: + include_parameter_indices = [] if isinstance(data, torch.Tensor): data.data = data.data.to(device=device) elif isinstance(data, (list, tuple)): for i, elem in enumerate(data): - if include_parameter_indices is None or i in include_parameter_indices: + if i in include_parameter_indices: to_(elem, device) elif isinstance(data, dict): for elem in data.values(): to_(elem, device) -def to(data: torch.Tensor | list | tuple | dict, device: torch.device) -> torch.Tensor | list | tuple | dict: - if isinstance(data, torch.Tensor): - return data.to(device=device) - elif isinstance(data, (list, tuple)): - for i in range(len(data)): - data[i] = to(data[i], device) - elif isinstance(data, dict): - for key, elem in data.items(): - data[key] = to(elem, device) - - def create_checkpointed_forward( orig_module: nn.Module, train_device: torch.device, @@ -82,7 +72,7 @@ def create_checkpointed_forward( include_from_offload_param_indices = __get_args_indices(orig_forward, include_from_offload_param_names) if offload_activations: - def custom_forward( + def offloaded_custom_forward( # dummy tensor that requires grad is needed for checkpointing to work when training a LoRA dummy: torch.Tensor = None, *args, @@ -96,6 +86,13 @@ def custom_forward( return output + def custom_forward( + # dummy tensor that requires grad is needed for checkpointing to work when training a LoRA + dummy: torch.Tensor = None, + *args, + ): + return orig_forward(*args) + def forward( *args, **kwargs @@ -107,7 +104,7 @@ def forward( args = __kwargs_to_args(orig_forward, args, kwargs) return checkpoint( - custom_forward, + offloaded_custom_forward, dummy, *args, use_reentrant=True @@ -136,14 +133,14 @@ def forward( dummy.requires_grad_(True) return checkpoint( - custom_forward, + offloaded_custom_forward, dummy, *args, **kwargs, use_reentrant=False ) else: - return custom_forward(None, *args, **kwargs) + return offloaded_custom_forward(None, *args, **kwargs) return forward From 70d882a2c9c21ed0d4c2ed9e2c751eed6de2b260 Mon Sep 17 00:00:00 2001 From: Nerogar Date: Sat, 7 Sep 2024 16:30:09 +0200 Subject: [PATCH 09/10] add XFormers support for SD3 --- .../modelSetup/BaseStableDiffusion3Setup.py | 7 +- .../XFormersJointAttnProcessor.py | 83 +++++++++++++++++++ 2 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 modules/modelSetup/stableDiffusion3/XFormersJointAttnProcessor.py diff --git a/modules/modelSetup/BaseStableDiffusion3Setup.py b/modules/modelSetup/BaseStableDiffusion3Setup.py index 2db47bbe..f3557b25 100644 --- a/modules/modelSetup/BaseStableDiffusion3Setup.py +++ b/modules/modelSetup/BaseStableDiffusion3Setup.py @@ -8,6 +8,7 @@ from modules.modelSetup.mixin.ModelSetupEmbeddingMixin import ModelSetupEmbeddingMixin from modules.modelSetup.mixin.ModelSetupFlowMatchingMixin import ModelSetupFlowMatchingMixin from modules.modelSetup.mixin.ModelSetupNoiseMixin import ModelSetupNoiseMixin +from modules.modelSetup.stableDiffusion3.XFormersJointAttnProcessor import XFormersJointAttnProcessor from modules.module.AdditionalEmbeddingWrapper import AdditionalEmbeddingWrapper from modules.util.checkpointing_util import ( create_checkpointed_forward, @@ -45,12 +46,10 @@ def _setup_optimizations( config: TrainConfig, ): if config.attention_mechanism == AttentionMechanism.DEFAULT: - pass - # 
model.transformer.set_attn_processor(AttnProcessor()) + model.transformer.set_attn_processor(JointAttnProcessor2_0()) elif config.attention_mechanism == AttentionMechanism.XFORMERS and is_xformers_available(): try: - # TODO: there is no xformers attention processor like JointAttnProcessor2_0 yet - # model.transformer.set_attn_processor(XFormersAttnProcessor()) + model.transformer.set_attn_processor(XFormersJointAttnProcessor(model.train_dtype.torch_dtype())) model.vae.enable_xformers_memory_efficient_attention() except Exception as e: print( diff --git a/modules/modelSetup/stableDiffusion3/XFormersJointAttnProcessor.py b/modules/modelSetup/stableDiffusion3/XFormersJointAttnProcessor.py new file mode 100644 index 00000000..a85c70a4 --- /dev/null +++ b/modules/modelSetup/stableDiffusion3/XFormersJointAttnProcessor.py @@ -0,0 +1,83 @@ +from typing import Optional + +import torch + +from diffusers.models.attention_processor import Attention, xformers + + +class XFormersJointAttnProcessor: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self, dtype: torch.dtype): + self.dtype = dtype + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> tuple[torch.FloatTensor, torch.FloatTensor]: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # attention + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + hidden_states = xformers.ops.memory_efficient_attention( + query.to(dtype=self.dtype), + key.to(dtype=self.dtype), + value.to(dtype=self.dtype), + ) + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. 
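+        # the sample tokens were concatenated ahead of the context tokens above, so the first
+        # residual.shape[1] positions are the sample stream and the remainder is the context stream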
+ hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states, encoder_hidden_states From 18d74c4a8dea0aac18aecafa4bc0852d982b1e09 Mon Sep 17 00:00:00 2001 From: Nerogar Date: Sat, 7 Sep 2024 18:46:30 +0200 Subject: [PATCH 10/10] fix checkpointing without offloading --- modules/util/checkpointing_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/util/checkpointing_util.py b/modules/util/checkpointing_util.py index da35586a..487be183 100644 --- a/modules/util/checkpointing_util.py +++ b/modules/util/checkpointing_util.py @@ -133,14 +133,14 @@ def forward( dummy.requires_grad_(True) return checkpoint( - offloaded_custom_forward, + custom_forward, dummy, *args, **kwargs, use_reentrant=False ) else: - return offloaded_custom_forward(None, *args, **kwargs) + return custom_forward(None, *args, **kwargs) return forward
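
A note on the loss added in PATCH 01: the expression diff + softplus(-2 * diff) - log(2) is the numerically stable form of log(cosh(diff)), since log(cosh(x)) = x + log(1 + e^(-2x)) - log(2). The following is a minimal standalone sketch (independent of the patches above, plain PyTorch only) that checks the identity:

    import math

    import torch
    import torch.nn.functional as F

    x = torch.linspace(-5.0, 5.0, steps=11, dtype=torch.float32)

    # direct evaluation, fine for moderate |x| but cosh() overflows for large inputs
    direct = torch.log(torch.cosh(x))

    # stable form used by __log_cosh_loss in ModelSetupDiffusionLossMixin (PATCH 01)
    stable = x + F.softplus(-2.0 * x) - math.log(2.0)

    assert torch.allclose(direct, stable, atol=1e-5)

    # for large residuals the direct form overflows while the stable form stays finite
    big = torch.tensor([200.0])
    print(torch.log(torch.cosh(big)))                     # inf
    print(big + F.softplus(-2.0 * big) - math.log(2.0))   # ~199.3069

The stable form matters because cosh() overflows float32 for inputs beyond roughly 89, while the softplus form stays finite for arbitrarily large residuals.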