Merge remote-tracking branch 'upstream/master' into dev
celll1 committed Sep 8, 2024
2 parents e157d75 + 1b749e2 commit 40fd26b
Showing 14 changed files with 309 additions and 98 deletions.
5 changes: 2 additions & 3 deletions modules/model/FluxModel.py
@@ -289,16 +289,15 @@ def encode_text(

return text_encoder_2_output, pooled_text_encoder_1_output

def prepare_latent_image_ids(self, batch_size, height, width, device, dtype):
def prepare_latent_image_ids(self, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)

return latent_image_ids.to(device=device, dtype=dtype)
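
With the batch_size parameter gone, prepare_latent_image_ids now returns a single 2-D tensor of shape [(height // 2) * (width // 2), 3] that is shared across the batch instead of being repeated per sample (the callers below are updated to match). A minimal standalone sketch of the new behavior, mirroring the diff rather than importing the module:

import torch

def prepare_latent_image_ids(height, width, device, dtype):
    # One (0, row, col) ID per 2x2 latent patch; channel 0 stays zero.
    ids = torch.zeros(height // 2, width // 2, 3)
    ids[..., 1] += torch.arange(height // 2)[:, None]
    ids[..., 2] += torch.arange(width // 2)[None, :]
    h, w, c = ids.shape
    # No per-batch repeat any more; the same IDs serve every sample.
    return ids.reshape(h * w, c).to(device=device, dtype=dtype)

ids = prepare_latent_image_ids(64, 64, torch.device("cpu"), torch.float32)
print(ids.shape)  # torch.Size([1024, 3])
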
3 changes: 1 addition & 2 deletions modules/modelSampler/FluxSampler.py
@@ -103,7 +103,6 @@ def __sample_base(
)

image_ids = self.model.prepare_latent_image_ids(
latent_image.shape[0],
height // vae_scale_factor,
width // vae_scale_factor,
self.train_device,
@@ -143,7 +142,7 @@ def __sample_base(
if "generator" in set(inspect.signature(noise_scheduler.step).parameters.keys()):
extra_step_kwargs["generator"] = generator

text_ids = torch.zeros(latent_image.shape[0], prompt_embedding.shape[1], 3, device=self.train_device)
text_ids = torch.zeros(prompt_embedding.shape[1], 3, device=self.train_device)

self.model.transformer_to(self.train_device)
for i, timestep in enumerate(tqdm(timesteps, desc="sampling")):
54 changes: 28 additions & 26 deletions modules/modelSetup/BaseFluxSetup.py
@@ -3,6 +3,7 @@

from modules.model.FluxModel import FluxModel, FluxModelEmbedding
from modules.modelSetup.BaseModelSetup import BaseModelSetup
from modules.modelSetup.flux.FluxXFormersAttnProcessor import FluxXFormersAttnProcessor
from modules.modelSetup.mixin.ModelSetupDebugMixin import ModelSetupDebugMixin
from modules.modelSetup.mixin.ModelSetupDiffusionLossMixin import ModelSetupDiffusionLossMixin
from modules.modelSetup.mixin.ModelSetupEmbeddingMixin import ModelSetupEmbeddingMixin
@@ -18,13 +19,17 @@
from modules.util.config.TrainConfig import TrainConfig
from modules.util.conv_util import apply_circular_padding_to_conv2d
from modules.util.dtype_util import create_autocast_context, disable_fp16_autocast_context
from modules.util.enum.AttentionMechanism import AttentionMechanism
from modules.util.enum.TrainingMethod import TrainingMethod
from modules.util.quantization_util import set_nf4_compute_type
from modules.util.TrainProgress import TrainProgress

import torch
from torch import Tensor

from diffusers.models.attention_processor import FluxAttnProcessor2_0
from diffusers.utils import is_xformers_available


class BaseFluxSetup(
BaseModelSetup,
@@ -41,30 +46,28 @@ def _setup_optimizations(
model: FluxModel,
config: TrainConfig,
):
# if config.attention_mechanism == AttentionMechanism.DEFAULT:
# pass
# # model.transformer.set_attn_processor(AttnProcessor())
# elif config.attention_mechanism == AttentionMechanism.XFORMERS and is_xformers_available():
# try:
# # TODO: there is no xformers attention processor like JointAttnProcessor2_0 yet
# # model.transformer.set_attn_processor(XFormersAttnProcessor())
# model.vae.enable_xformers_memory_efficient_attention()
# except Exception as e:
# print(
# "Could not enable memory efficient attention. Make sure xformers is installed"
# f" correctly and a GPU is available: {e}"
# )
# elif config.attention_mechanism == AttentionMechanism.SDP:
# model.transformer.set_attn_processor(JointAttnProcessor2_0())
#
# if is_xformers_available():
# try:
# model.vae.enable_xformers_memory_efficient_attention()
# except Exception as e:
# print(
# "Could not enable memory efficient attention. Make sure xformers is installed"
# f" correctly and a GPU is available: {e}"
# )
if config.attention_mechanism == AttentionMechanism.DEFAULT:
model.transformer.set_attn_processor(FluxAttnProcessor2_0())
elif config.attention_mechanism == AttentionMechanism.XFORMERS and is_xformers_available():
try:
model.transformer.set_attn_processor(FluxXFormersAttnProcessor(model.train_dtype.torch_dtype()))
model.vae.enable_xformers_memory_efficient_attention()
except Exception as e:
print(
"Could not enable memory efficient attention. Make sure xformers is installed"
f" correctly and a GPU is available: {e}"
)
elif config.attention_mechanism == AttentionMechanism.SDP:
model.transformer.set_attn_processor(FluxAttnProcessor2_0())

if is_xformers_available():
try:
model.vae.enable_xformers_memory_efficient_attention()
except Exception as e:
print(
"Could not enable memory efficient attention. Make sure xformers is installed"
f" correctly and a GPU is available: {e}"
)

if config.gradient_checkpointing.enabled():
enable_checkpointing_for_flux_transformer(
@@ -444,12 +447,11 @@ def predict(
guidance = None

text_ids = torch.zeros(
size=(latent_image.shape[0], text_encoder_output.shape[1], 3),
size=(text_encoder_output.shape[1], 3),
device=self.train_device,
)

image_ids = model.prepare_latent_image_ids(
latent_image.shape[0],
latent_input.shape[2],
latent_input.shape[3],
self.train_device,
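
With the commented-out block replaced by live code, the attention backend for Flux is now selected from the training config: DEFAULT and SDP both register diffusers' FluxAttnProcessor2_0, while XFORMERS registers the new FluxXFormersAttnProcessor (built with the training dtype) and also enables memory-efficient attention on the VAE. A hedged configuration sketch, assuming a populated TrainConfig named config:

from modules.util.enum.AttentionMechanism import AttentionMechanism

# Select the xformers path; _setup_optimizations() then calls
# model.transformer.set_attn_processor(FluxXFormersAttnProcessor(...)) and
# model.vae.enable_xformers_memory_efficient_attention(), printing a warning
# if enabling xformers fails.
config.attention_mechanism = AttentionMechanism.XFORMERS
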
7 changes: 3 additions & 4 deletions modules/modelSetup/BaseStableDiffusion3Setup.py
@@ -8,6 +8,7 @@
from modules.modelSetup.mixin.ModelSetupEmbeddingMixin import ModelSetupEmbeddingMixin
from modules.modelSetup.mixin.ModelSetupFlowMatchingMixin import ModelSetupFlowMatchingMixin
from modules.modelSetup.mixin.ModelSetupNoiseMixin import ModelSetupNoiseMixin
from modules.modelSetup.stableDiffusion3.XFormersJointAttnProcessor import XFormersJointAttnProcessor
from modules.module.AdditionalEmbeddingWrapper import AdditionalEmbeddingWrapper
from modules.util.checkpointing_util import (
create_checkpointed_forward,
@@ -45,12 +46,10 @@ def _setup_optimizations(
config: TrainConfig,
):
if config.attention_mechanism == AttentionMechanism.DEFAULT:
pass
# model.transformer.set_attn_processor(AttnProcessor())
model.transformer.set_attn_processor(JointAttnProcessor2_0())
elif config.attention_mechanism == AttentionMechanism.XFORMERS and is_xformers_available():
try:
# TODO: there is no xformers attention processor like JointAttnProcessor2_0 yet
# model.transformer.set_attn_processor(XFormersAttnProcessor())
model.transformer.set_attn_processor(XFormersJointAttnProcessor(model.train_dtype.torch_dtype()))
model.vae.enable_xformers_memory_efficient_attention()
except Exception as e:
print(
112 changes: 112 additions & 0 deletions modules/modelSetup/flux/FluxXFormersAttnProcessor.py
@@ -0,0 +1,112 @@
from typing import Optional

import torch

from diffusers.models.attention_processor import Attention, xformers


class FluxXFormersAttnProcessor:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

def __init__(self, dtype: torch.dtype):
self.dtype = dtype

def apply_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None].transpose(1, 2)
sin = sin[None, None].transpose(1, 2)
cos, sin = cos.to(x.device), sin.to(x.device)

x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)

out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

return out

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
)

if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

if image_rotary_emb is not None:

query = self.apply_rotary_emb(query, image_rotary_emb)
key = self.apply_rotary_emb(key, image_rotary_emb)

hidden_states = xformers.ops.memory_efficient_attention(
query.to(dtype=self.dtype),
key.to(dtype=self.dtype),
value.to(dtype=self.dtype),
)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

return hidden_states, encoder_hidden_states
else:
return hidden_states
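
A note on the xformers call above: xformers.ops.memory_efficient_attention expects query/key/value in [batch, seq_len, heads, head_dim] layout, which is why the projections are only .view()-ed and never transposed to [batch, heads, seq_len, head_dim] as in the SDP processor, and why q/k/v are cast to the dtype handed to the constructor. A minimal shape sketch, assuming xformers and a CUDA device are available (the sizes are illustrative, not taken from the diff):

import torch
import xformers.ops

B, S, H, D = 2, 1024, 24, 128   # illustrative batch, tokens, heads, head_dim
q = torch.randn(B, S, H, D, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = xformers.ops.memory_efficient_attention(q, k, v)  # [B, S, H, D]
out = out.reshape(B, S, H * D)                          # merge heads, as in the processor
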
28 changes: 28 additions & 0 deletions modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py
@@ -56,6 +56,15 @@ def __align_prop_losses(
losses = self.__align_prop_loss_fn(data['predicted'])

return losses * config.align_prop_weight

def __log_cosh_loss(
self,
pred: torch.Tensor,
target: torch.Tensor,
):
diff = pred - target
loss = diff + torch.nn.functional.softplus(-2.0*diff) - torch.log(torch.full(size=diff.size(), fill_value=2.0, dtype=torch.float32, device=diff.device))
return loss

def __masked_losses(
self,
Expand Down Expand Up @@ -91,6 +100,18 @@ def __masked_losses(
normalize_masked_area_loss=config.normalize_masked_area_loss,
).mean([1, 2, 3]) * config.mae_strength

# log-cosh Loss
if config.log_cosh_strength != 0:
losses += masked_losses(
losses=self.__log_cosh_loss(
data['predicted'].to(dtype=torch.float32),
data['target'].to(dtype=torch.float32)
),
mask=batch['latent_mask'].to(dtype=torch.float32),
unmasked_weight=config.unmasked_weight,
normalize_masked_area_loss=config.normalize_masked_area_loss,
).mean([1, 2, 3]) * config.log_cosh_strength

# VB loss
if config.vb_loss_strength != 0 and 'predicted_var_values' in data and self.__coefficients is not None:
losses += masked_losses(
@@ -133,6 +154,13 @@ def __unmasked_losses(
reduction='none'
).mean([1, 2, 3]) * config.mae_strength

# log-cosh Loss
if config.log_cosh_strength != 0:
losses += self.__log_cosh_loss(
data['predicted'].to(dtype=torch.float32),
data['target'].to(dtype=torch.float32)
).mean([1, 2, 3]) * config.log_cosh_strength

# VB loss
if config.vb_loss_strength != 0 and 'predicted_var_values' in data:
losses += vb_losses(
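
The log-cosh loss added above is computed in its numerically stable form, log(cosh(x)) = x + softplus(-2x) - log(2), instead of calling cosh directly, which overflows for large residuals. A small standalone check of that identity (a sketch, not part of the mixin; the mixin builds the log(2) constant as a tensor):

import math

import torch
import torch.nn.functional as F

def log_cosh(diff: torch.Tensor) -> torch.Tensor:
    # Stable form: log(cosh(x)) = x + softplus(-2x) - log(2)
    return diff + F.softplus(-2.0 * diff) - math.log(2.0)

x = torch.tensor([0.0, 0.5, 5.0, 100.0])
print(log_cosh(x))               # finite everywhere, ~99.3 at x = 100
print(torch.log(torch.cosh(x)))  # naive form: inf at x = 100 in float32
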
83 changes: 83 additions & 0 deletions modules/modelSetup/stableDiffusion3/XFormersJointAttnProcessor.py
@@ -0,0 +1,83 @@
from typing import Optional

import torch

from diffusers.models.attention_processor import Attention, xformers


class XFormersJointAttnProcessor:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

def __init__(self, dtype: torch.dtype):
self.dtype = dtype

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
residual = hidden_states

input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size = encoder_hidden_states.shape[0]

# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

# attention
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)

hidden_states = xformers.ops.memory_efficient_attention(
query.to(dtype=self.dtype),
key.to(dtype=self.dtype),
value.to(dtype=self.dtype),
)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

return hidden_states, encoder_hidden_states