
Add support for attention masking in Flux #5942

Merged
17 commits merged into comfyanonymous:master on Dec 16, 2024

Conversation

Slickytail
Contributor

This PR allows the DoubleStreamBlock and SingleStreamBlock of Flux to accept an attention mask.
I also fixed a bug in the xformers attention call that caused an incorrect mask shape when using skip_reshape=True, as is the case in Flux.

This PR provides no direct mechanism for setting attention masks in the Flux model (I considered adding a key to transformer_options that sets masks, and maybe that is how you would prefer to do it), but it can be done easily using object patches; a sketch follows below. As an example of how to do this, and of a useful tool enabled by this change, the PR includes a Redux node that uses attention bias to change the weight of the image, rather than multiplying the CLIP embeds.
Let me know if you would rather I remove this node and keep it as just a custom node. I think having it in the code is useful because it makes clear how to implement other nodes that set attention masks.
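For illustration, a minimal sketch of the object-patch route, assuming ComfyUI's ModelPatcher API (clone, get_model_object, add_object_patch) and a forward_orig that accepts the attn_mask argument this PR adds; the name patch_flux_attention is hypothetical:

def patch_flux_attention(model, attn_mask):
    # clone so the patch doesn't leak into other workflows
    m = model.clone()
    flux = m.get_model_object("diffusion_model")
    orig_forward = flux.forward_orig

    def forward_with_mask(*args, **kwargs):
        kwargs["attn_mask"] = attn_mask  # inject our mask on every call
        return orig_forward(*args, **kwargs)

    m.add_object_patch("diffusion_model.forward_orig", forward_with_mask)
    return m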

@comfyanonymous
Owner

This is great. What do you think about adding attention masking support to the flux code in ComfyUI directly?

For other models the attention mask/bias is passed as part of the "conditioning" object: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_base.py#L796 ; this also means it could be added as an extra option to the existing apply style model node instead of needing a new node.

@Slickytail
Contributor Author

Glad to take the opportunity to make it more flexible, and I'm confident that I understand the conditioning code well enough.
I'm unsure, however, what to do about the fact that the image size is unknown when the conditioning is created. The only approach I can think of that's sufficiently flexible (i.e., one that allows an arbitrary, full attention mask to be specified) is for the attention mask conditioning to contain both the mask AND a "reference shape" for the image (probably 32x32 or so), which indicates how the (flattened) image-token part of the mask should be interpreted. Then, once the model is called and the shape of the image is known, the model could extract the txt-img, img-img, and img-txt blocks of the attention mask, reshape them to a spatial map based on the reference shape, upscale the map, reflatten it, and reassemble the pieces of the mask; a sketch of that rescaling step follows.

@comfyanonymous do you think this is too convoluted? If not, I'll go ahead and implement it. Another option would be to do what xformers does and add some simple composable types representing block-diagonal, causal, etc. masks for common mask forms with potentially unknown shapes, which could then be "realized" into tensors once the image size is known.
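A minimal sketch of the rescaling described above, with hypothetical names and a nearest-neighbor upscale as an assumption (the img-img block would need this treatment along both axes):

import torch
import torch.nn.functional as F

def rescale_txt_img_block(block, ref_shape, img_shape):
    # block: (n_txt, ref_h * ref_w) additive bias, txt queries -> img keys
    n_txt = block.shape[0]
    # unflatten to the reference grid, then upscale to the real latent grid
    spatial = block.reshape(n_txt, 1, ref_shape[0], ref_shape[1])
    spatial = F.interpolate(spatial, size=img_shape, mode="nearest")
    # reflatten so the block can be reassembled into the full mask
    return spatial.reshape(n_txt, img_shape[0] * img_shape[1])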

@comfyanonymous
Owner

The conditioning can contain the attention mask for only the conditioning itself. Then here https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_base.py#L771 you can process the attention mask.

Here's an example of how the "attention_mask" in the conditioning is passed as the "text_embedding_mask" argument of the hunyuan-dit model: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_base.py#L702

If you need the shape of the latent: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_base.py#L663

You can also do it directly in the model code.
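For reference, a sketch of the pattern at those links, modeled on the hunyuan-dit example (the class name here is hypothetical; exact details may differ between versions):

import comfy.conds
from comfy.model_base import BaseModel

class FluxModel(BaseModel):  # illustrative subclass
    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            # wrap the mask so it is batched and moved alongside the other
            # conds; kwargs also carries the noise tensor at this point, so
            # latent-shape-dependent processing (e.g. interpolation) fits here
            out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
        return out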

@Slickytail
Contributor Author

Yes, I see now how the extra_conds function can modify the attention mask.

The problem is that most applications of attention masking involve modifying attention between certain parts of the image and certain parts of the text. (For example, my use case is to make the image pay less attention to the redux tokens, but I expect most people will use this for efficient regional prompting.) So the conditioning can't contain only the conditioning attention mask. The full attention mask for Flux (and other MMDiTs) has shape (n_tokens_txt + n_tokens_img, n_tokens_txt + n_tokens_img), and the part of the mask corresponding to the image is a flattened version of the 2D patches, so it can't be interpolated to the right size without knowing the original shape; see the layout sketch below.

In any case, I've gone with the implementation I mentioned, where the mask is created with a "placeholder" image size and then interpolated by the extra_conds function. I've also moved the redux weighting code into the normal style application node.

Let me know what you think of this architecture; I know it's a little convoluted.
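To illustrate the joint mask layout being discussed (an illustrative sketch; the token counts are hypothetical):

import torch

n_txt, h, w = 256, 32, 32       # hypothetical text length and latent grid
n_img = h * w                   # image tokens are flattened 2D patches
n = n_txt + n_img

mask = torch.zeros(n, n)        # additive bias; 0 means "attend normally"
txt_txt = mask[:n_txt, :n_txt]  # text  -> text
txt_img = mask[:n_txt, n_txt:]  # text  -> image
img_txt = mask[n_txt:, :n_txt]  # image -> text
img_img = mask[n_txt:, n_txt:]  # image -> image; each image axis is a
                                # flattened (h, w) grid, which is why resizing
                                # the mask requires the original spatial shape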

This corrects a weird inconsistency with skip_reshape.
It also allows masks of various shapes to be passed, which will be
automatically expanded (in a memory-efficient way) to a size that is
compatible with xformers or pytorch sdpa respectively.
@comfyanonymous
Owner

I get an error when running LTX video with xformers, using the first workflow on this page: https://comfyanonymous.github.io/ComfyUI_examples/ltxv/

@Slickytail
Contributor Author

Fixed. I think the modified PyTorch/xformers attention functions should now be compatible with all reasonable mask shapes ([nq, nk], [1|b, nq, nk], [1|b, 1|h, nq, nk]); a sketch of the expansion is below.
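A minimal sketch of that normalization, assuming an additive bias; expand returns a broadcast view, so no memory is copied (the function name is hypothetical):

import torch

def normalize_mask(mask, b, heads, nq, nk):
    # accept [nq, nk], [1|b, nq, nk], or [1|b, 1|h, nq, nk]
    if mask.ndim == 2:
        mask = mask.unsqueeze(0)   # -> [1, nq, nk]
    if mask.ndim == 3:
        mask = mask.unsqueeze(1)   # -> [1|b, 1, nq, nk]
    # expand broadcasts size-1 dims without allocating new memory
    return mask.expand(b, heads, nq, nk)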

@comfyanonymous comfyanonymous merged commit 61b5072 into comfyanonymous:master Dec 16, 2024
5 checks passed
@recris

recris commented Dec 17, 2024

I have some questions on how to use this.
Is this about masking regions of the output to control how the conditioning affects parts of the image, or is it for "ignoring" parts of the conditioning regardless of the output structure (let's say I am applying Redux conditioning using an image with two objects, but I want to filter out one of them)? Do you have any examples (or code) of how to use this?

Some weeks ago I tried to create a version of the Apply Style Model node that took an extra mask image as input. The node used the image to approximate an attention mask for the DiT patches, then fed it into the Flux model... It kind of worked, but with some funny results: if I painted a region with the mask, it resulted in a "hole" in the output, an effect similar to a naive in-painting algorithm over the masked region. Maybe I was doing something wrong in the code.

@Slickytail
Contributor Author

Hello @recris,

The advantage of this approach is that you can set the attention weight arbitrarily between any image token and any text token.
If you want to redux an image that has two objects and ignore one of them, you can do that.
If you want the generated image to pay less attention to a certain word in the input, you can do that.
If you want a certain part of the generated image to pay less attention to a certain part of the redux image, you can do that.

I'll give an example here of how to mask out part of the redux image. The effect is that the entire generated image becomes unable to see a specific part of the redux image. Note, however, that some information from the masked region of the redux image will still leak out, since the tokens going into the redux model are encoded SigLIP tokens, which necessarily contain global as well as local information.
For simplicity, we'll overwrite any existing mask instead of expanding/adding to it.

import torch

# cond here is the output from redux
cond = ...
n = cond.shape[1]  # = 29*29
c_out = []
# replace this with your mask;
# it should be stretched to 29x29, since siglip outputs 29x29 patches.
image_mask = torch.ones((29, 29))
# log turns the binary mask into an additive bias: 1 -> 0, 0 -> -inf
image_mask = torch.log(image_mask.flatten())
for t in conditioning:
    (txt, keys) = t
    keys = keys.copy()
    # since we don't care about masking specific parts of the *generated*
    # image, we can just use mask_ref_size = (1, 1)
    mask_ref_size = (1, 1)
    n_ref = mask_ref_size[0] * mask_ref_size[1]
    n_txt = txt.shape[1]
    # create a new mask covering the text, redux, and reference image tokens
    new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16)
    # now fill in the attention bias to our redux tokens:
    # make the text tokens pay less attention to the specified redux tokens
    new_mask[:, :n_txt, n_txt:n_txt + n] = image_mask
    # and then make the image tokens pay less attention to the specified redux tokens
    new_mask[:, n_txt + n:, n_txt:n_txt + n] = image_mask
    keys["attention_mask"] = new_mask.to(txt.device)
    keys["attention_mask_img_shape"] = mask_ref_size

    c_out.append([torch.cat((txt, cond), dim=1), keys])

return (c_out, )

@EnragedAntelope

FYI, regarding the issue mentioned at gseth/ControlAltAI-Nodes#13:
these Flux attention changes seem to have broken those nodes, with the error included in that issue.
Similarly, I believe PuLID for Flux has been negatively affected.
Is there easy guidance for node creators to incorporate a change, or a way for them to disable/override the recent native Comfy attention changes so their nodes continue working?

@Slickytail
Contributor Author

Slickytail commented Dec 20, 2024 via email

Yes, custom node creators who override the forward_orig method can simply add an attn_mask argument to the function signature. They can either use the mask in the same way as the original function (i.e., pass it to the DoubleStreamBlocks and SingleStreamBlocks) or ignore it. Either way, it's just a matter of adding the argument; a sketch follows.
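A hedged sketch of that change (the parameter list is abridged and may differ between ComfyUI versions):

# inside a custom node's patched Flux model
def forward_orig(self, img, img_ids, txt, txt_ids, timesteps, y,
                 guidance=None, control=None, transformer_options={},
                 attn_mask=None):  # new argument added by this PR
    ...
    # either thread it through to the blocks, as upstream now does:
    #   img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
    # ...or simply accept it and ignore it.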

@gseth

gseth commented Dec 20, 2024

Thanks, that worked.

@recris

recris commented Dec 28, 2024

I was able to get input masking working based on the sample code provided.

Here is a clumsy custom node that implements the effect:

import torch
import torch.nn.functional as F


class StyleModelApplyMasked:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "style_model": ("STYLE_MODEL", ),
                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             "mask": ("MASK",),
                             "attn_bias": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                             "max_bias": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_stylemodel"

    CATEGORY = "conditioning/style_model"

    def apply_stylemodel(self, conditioning, style_model, clip_vision_output, mask, attn_bias, max_bias):
        # project the CLIP vision output through the style model into prompt tokens
        cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)

        n = cond.shape[1]
        c_out = []

        # downscale the user mask to the 27x27 patch grid of the redux tokens
        image_mask = F.interpolate(torch.unsqueeze(mask, 1), size=(27, 27), mode='bilinear')
        image_mask = torch.squeeze(image_mask, 1)
        # log turns the mask into an additive bias: 1 -> 0, 0 -> -inf
        image_mask = torch.log(image_mask.flatten())

        # for some reason I get black images if the values are not constrained
        image_mask = torch.clamp(image_mask, min=-max_bias, max=-attn_bias)

        for t in conditioning:
            (txt, keys) = t
            keys = keys.copy()

            mask_ref_size = (1,1)
            n_ref = mask_ref_size[0] * mask_ref_size[1]
            n_txt = txt.shape[1]
            new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=txt.dtype)

            new_mask[:, :n_txt, n_txt:n_txt + n] = image_mask
            new_mask[:, n_txt + n:, n_txt:n_txt + n] = image_mask

            keys["attention_mask"] = new_mask.to(txt.device)
            keys["attention_mask_img_shape"] = mask_ref_size

            c_out.append([torch.cat((txt, cond), dim=1), keys])

        return (c_out,)

@Slickytail Redux tokens seem to be 27x27 patches, not 29x29. The effect also doesn't seem to be 100% effective; still investigating.

Also, I am getting black images in the output if the attention mask contains large negative values, for some reason; not sure if there is a bug in this code or somewhere else.
