Add support for attention masking in Flux #5942
Conversation
This is great, what do you think about adding attention masking support to the flux code in ComfyUI directly? For other models the attention mask/bias is passed as part of the "conditioning" object: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_base.py#L796. This also means it could be added as an extra option to the existing apply style model node instead of needing a new node.
Glad to take the opportunity to make it more flexible, and I'm confident that I understand the conditioning code well enough. @comfyanonymous do you think this is too convoluted? If not, I'll go ahead and implement it. Another option would be to do what xformers does and add some simple composable types representing block-diagonal, causal, etc. masks for common mask forms with potentially unknown shapes, which could then be "realized" into tensors once the image size is known.
The conditioning can contain the attention mask for only the conditioning itself. Then here https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_base.py#L771 you can process the attention mask. Here's an example of how the "attention_mask" in the conditioning is passed as the "text_embedding_mask" argument of the hunyuan-dit model: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_base.py#L702 If you need the shape of the latent: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_base.py#L663 You can also do it directly in the model code.
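For reference, the pattern described here looks roughly like the following minimal sketch (not this PR's actual code; the subclass name and the "attn_mask" output key are assumptions and would have to match whatever argument name the model's forward expects):

```python
import comfy.conds
import comfy.model_base

class FluxWithAttnMask(comfy.model_base.Flux):
    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        # "attention_mask" is whatever key the node stored in the conditioning dict
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            # wrap it so it gets repeated to batch size and moved to the
            # right device like the other conditioning tensors
            out["attn_mask"] = comfy.conds.CONDRegular(attention_mask)
        return out
```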
Yes, I see now how the extra_conds function can modify the attention mask. The problem is that most applications of attention masking involve modifying attention between certain parts of the image and certain parts of the text (for example, my use case is to make the image pay less attention to the redux tokens, but I expect that most people will use this for efficient regional prompting), so the conditioning can't contain only the conditioning attention mask. The full attention mask for flux (and other MMDiTs) has shape (n_tokens_txt + n_tokens_img, n_tokens_txt + n_tokens_img), and the part of the mask corresponding to the image is a flattened version of the 2d patches, so it can't be interpolated to the right size without knowing the original shape. In any case, I've gone for the implementation I mentioned where the mask is created with a "placeholder" image size, and then interpolated by the extra_conds function. I've also moved the redux weighting code to the normal style application node. Let me know what you think of this architecture -- I know it's a little convoluted.
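To make the shape problem concrete, here is a rough sketch (a hypothetical helper, not part of this PR) of what resizing just the text-to-image block of such a bias involves; the placeholder patch grid has to be stored alongside the mask so the flattened image axis can be unflattened before interpolation:

```python
import torch
import torch.nn.functional as F

def resize_txt_to_img_block(block: torch.Tensor, ref_hw, target_hw) -> torch.Tensor:
    """block: (B, n_txt, ref_h * ref_w) attention bias built against a
    placeholder patch grid; target_hw: the real (h, w) patch grid once the
    latent size is known. Returns (B, n_txt, h * w)."""
    b, n_txt, _ = block.shape
    ref_h, ref_w = ref_hw
    h, w = target_hw
    # unflatten the image axis to 2D, resize it, then flatten it again
    block = block.reshape(b * n_txt, 1, ref_h, ref_w)
    block = F.interpolate(block, size=(h, w), mode="nearest")
    return block.reshape(b, n_txt, h * w)
```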
This corrects a weird inconsistency with skip_reshape. It also allows masks of various shapes to be passed, which will be automatically expanded (in a memory-efficient way) to a size that is compatible with xformers or pytorch sdpa, respectively.
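The "memory-efficient" expansion referred to here is broadcasting rather than copying; roughly like this sketch (assumed shapes, not the PR's exact code):

```python
import torch

def expand_mask(mask: torch.Tensor, batch: int, heads: int) -> torch.Tensor:
    # Bring a 2D (L_q, L_k) or 3D (B, L_q, L_k) mask up to the 4D
    # (B, heads, L_q, L_k) shape expected by pytorch sdpa / xformers.
    if mask.ndim == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)
    elif mask.ndim == 3:
        mask = mask.unsqueeze(1)
    # expand() returns a broadcasted view, so no per-head copies are materialized
    return mask.expand(batch, heads, mask.shape[-2], mask.shape[-1])
```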
I get an error when running ltx video with xformers, the first workflow on this page: https://comfyanonymous.github.io/ComfyUI_examples/ltxv/
Fixed. I think the modified pytorch/xformers attention functions should now be compatible with all reasonable mask shapes.
I have some questions on how to use this. Some weeks ago I tried to create a version of the Apply Style Model node that took an extra mask image as input. The node used the image to approximate an attention mask for the DiT patches, then fed it into the Flux model... it kind of worked, but with some funny results (if I painted a region with the mask, it resulted in a "hole" in the output, an effect similar to a dumb in-painting algorithm over the masked region). Maybe I was doing something wrong in the code.
Hello @recris, the advantage of this approach is that you can set the attention weight arbitrarily between any image token and any text token. Here's an example of how to mask out part of the redux image. The effect would be that the entire generated image would be unable to see a specific part of the redux image. I would, however, note that some information from the masked region of the redux image will still leak out, since the tokens going into the redux model are encoded siglip tokens, which necessarily contain global information as well as local.
# cond here is the output from redux
cond = ...
n = cond.shape[1] # = 29*29
c_out = []
# replace this with your mask.
# it should be stretched to 29x29, since siglip outputs 29x29 patches.
image_mask = torch.ones((29, 29))
# false -> -inf, true -> 0
image_mask = torch.log(image_mask.flatten())
for t in conditioning:
    (txt, keys) = t
    keys = keys.copy()
    # since we don't care about masking specific parts of the *generated* image, we can just use mask_ref_size = (1, 1)
    mask_ref_size = (1, 1)
    n_ref = mask_ref_size[0] * mask_ref_size[1]
    n_txt = txt.shape[1]
    # create a new mask for the attention calls inside flux
    # (the n redux tokens get concatenated to the text, hence n_txt + n + n_ref)
    mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16)
    # now fill in the attention bias to our redux tokens
    # make the text tokens pay less attention to the specified redux tokens
    mask[:, :n_txt, n_txt:n_txt+n] = image_mask
    # and then make the image tokens pay less attention to the specified redux tokens
    mask[:, n_txt+n:, n_txt:n_txt+n] = image_mask
    keys["attention_mask"] = mask.to(txt.device)
    keys["attention_mask_img_shape"] = mask_ref_size
    c_out.append([torch.cat((txt, cond), dim=1), keys])
return (c_out, )
FYI regarding the issue mentioned at gseth/ControlAltAI-Nodes#13,
Yes, custom node creators who override the forward_orig method can simply add an attn_mask argument to the function signature. They could either use the mask in the same way as in the original function (i.e., passing it to the DoubleBlocks and SingleBlocks) or they could ignore it. Either way, it's just about adding the argument to the function.
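For custom node authors, the change amounts to something like the following sketch (hedged: only the attn_mask argument comes from this PR; the other argument names should mirror whatever the existing override already takes):

```python
# Old override (for comparison):
# def forward_orig(self, img, img_ids, txt, txt_ids, timesteps, y,
#                  guidance=None, control=None):
#     ...

# New override: accept attn_mask and, ideally, forward it to the double/single
# stream blocks. Accepting it and simply ignoring it also works.
def forward_orig(self, img, img_ids, txt, txt_ids, timesteps, y,
                 guidance=None, control=None, attn_mask=None):
    # existing custom logic goes here; when calling the blocks, pass e.g.
    #   img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
    pass
```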
Thanks, that worked.
I was able to get input masking working based on the sample code provided. Here is a clumsy custom node that implements the effect:
import torch
import torch.nn.functional as F
class StyleModelApplyMasked:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "style_model": ("STYLE_MODEL", ),
                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             "mask": ("MASK",),
                             "attn_bias": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                             "max_bias": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
                             }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_stylemodel"
    CATEGORY = "conditioning/style_model"

    def apply_stylemodel(self, conditioning, style_model, clip_vision_output, mask, attn_bias, max_bias):
        cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
        n = cond.shape[1]
        c_out = []
        # resize the input mask to the redux patch grid (27x27) before flattening it into the bias
        image_mask = F.interpolate(torch.unsqueeze(mask, 1), size=(27, 27), mode='bilinear')
        image_mask = torch.squeeze(image_mask, 1)
        image_mask = torch.log(image_mask.flatten())
        # for some reason I get black images if the values are not constrained
        image_mask = torch.clamp(image_mask, min=-max_bias, max=-attn_bias)
        for t in conditioning:
            (txt, keys) = t
            keys = keys.copy()
            mask_ref_size = (1, 1)
            n_ref = mask_ref_size[0] * mask_ref_size[1]
            n_txt = txt.shape[1]
            new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=txt.dtype)
            new_mask[:, :n_txt, n_txt:n_txt + n] = image_mask
            new_mask[:, n_txt + n:, n_txt:n_txt + n] = image_mask
            keys["attention_mask"] = new_mask.to(txt.device)
            keys["attention_mask_img_shape"] = mask_ref_size
            c_out.append([torch.cat((txt, cond), dim=1), keys])
        return (c_out,)

@Slickytail Redux tokens seem to be 27x27 patches, not 29x29. The effect seems to not be 100% effective; still investigating. Also, I am getting black images in the output if the attention mask contains large negative values, and I'm not sure if there is a bug in this code or somewhere else.
This PR allows the DoubleStreamBlock and SingleStreamBlock of Flux to accept an attention mask. I also fixed a bug in the xformers attention call that caused an incorrect mask shape when using skip_reshape=True, as is the case in flux.

This PR provides no direct mechanism for setting attention masks in the flux model (I thought about adding a key to transformer_options that sets masks, and maybe that would be how you prefer to do it), but it can be done easily using object patches. As an example of how to do this, and a useful tool enabled by this change, the PR includes a Redux node that uses attention bias to change the weight of the image, rather than multiplying the clip embeds.

Let me know if you would rather that I remove this node and keep it as just a custom node. I think having it in the code is useful because it makes it clear how to implement other nodes that set attention masks.
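For completeness, the "object patches" route mentioned above could look roughly like this (a hedged sketch: apply_attn_mask and the assumption that the mask is already on the right device/dtype are mine; only ModelPatcher.clone/get_model_object/add_object_patch and the new attn_mask argument come from ComfyUI and this PR):

```python
import types

def apply_attn_mask(model, attn_mask):
    """model: a ModelPatcher wrapping a Flux model; attn_mask: a precomputed
    attention bias, assumed to already be on the model's device/dtype."""
    m = model.clone()
    flux = m.get_model_object("diffusion_model")
    original = flux.forward_orig  # bound method of the unpatched module

    def forward_orig(self, *args, **kwargs):
        # inject the mask on every call, then defer to the original method
        kwargs["attn_mask"] = attn_mask
        return original(*args, **kwargs)

    # shadow forward_orig on the module instance for the lifetime of the patch
    m.add_object_patch("diffusion_model.forward_orig", types.MethodType(forward_orig, flux))
    return m
```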