Commit
Merge remote-tracking branch 'upstream/master' into dev
Showing 14 changed files with 309 additions and 98 deletions.
@@ -0,0 +1,112 @@
from typing import Optional

import torch

from diffusers.models.attention_processor import Attention, xformers


class FluxXFormersAttnProcessor:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self, dtype: torch.dtype):
        self.dtype = dtype

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None].transpose(1, 2)
        sin = sin[None, None].transpose(1, 2)
        cos, sin = cos.to(x.device), sin.to(x.device)

        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            )
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            )
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            )

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

        if image_rotary_emb is not None:
            query = self.apply_rotary_emb(query, image_rotary_emb)
            key = self.apply_rotary_emb(key, image_rotary_emb)

        hidden_states = xformers.ops.memory_efficient_attention(
            query.to(dtype=self.dtype),
            key.to(dtype=self.dtype),
            value.to(dtype=self.dtype),
        )
        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
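Not part of the diff above, but as a hedged sketch of how a processor like this is typically wired up: diffusers transformer models expose set_attn_processor, so one plausible integration looks like the following. The model id, dtype, and the processor import path are illustrative assumptions, not taken from this commit.

# Hypothetical usage sketch (not part of this commit). The import path below is a
# guess at where the class might live, and the model id/dtype are illustrative.
import torch
from diffusers import FluxTransformer2DModel

from modules.modelSetup.flux.FluxXFormersAttnProcessor import FluxXFormersAttnProcessor  # hypothetical path

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Install the same processor instance on every attention block; xformers then runs
# the attention math in bfloat16, as requested via the constructor argument.
transformer.set_attn_processor(FluxXFormersAttnProcessor(dtype=torch.bfloat16))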
modules/modelSetup/stableDiffusion3/XFormersJointAttnProcessor.py
83 changes: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
from typing import Optional

import torch

from diffusers.models.attention_processor import Attention, xformers


class XFormersJointAttnProcessor:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self, dtype: torch.dtype):
        self.dtype = dtype

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # attention
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)

        hidden_states = xformers.ops.memory_efficient_attention(
            query.to(dtype=self.dtype),
            key.to(dtype=self.dtype),
            value.to(dtype=self.dtype),
        )
        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states
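For context, a minimal sketch of attaching this joint-attention processor to an SD3 transformer via diffusers' set_attn_processor; the model id and dtype below are placeholders and are not part of this commit.

# Hypothetical usage sketch (not part of this commit). Model id and dtype are
# placeholders; the processor import path matches the file added above.
import torch
from diffusers import SD3Transformer2DModel

from modules.modelSetup.stableDiffusion3.XFormersJointAttnProcessor import XFormersJointAttnProcessor

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer", torch_dtype=torch.float16
)

# Apply the xformers joint-attention processor to all attention blocks; the attention
# itself is computed in float16 while the projections keep the model's dtype.
transformer.set_attn_processor(XFormersJointAttnProcessor(dtype=torch.float16))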