Skip to content

Commit

Permalink
Add MaHiRo (improved/alternate CFG) (#5975)
Browse files Browse the repository at this point in the history
* Add MaHiRo (improved CFG)

long explanation of what it is is [here](https://huggingface.co/spaces/yoinked/blue-arxiv) (2024-1208.1) 


note: if the node name has encoding issues (utf 8/whatever), id suggest to replace the face at the end with `(>w<)`

* add it to nodes.py, add description, and make it a post_cfg function

* fix

* revert the sampler_cfg_function thing

* switch cfg to args["denoised"]
  • Loading branch information
yoinked-h authored Dec 11, 2024
1 parent 5def9fb commit 5bea1d2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
41 changes: 41 additions & 0 deletions comfy_extras/nodes_mahiro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
import torch.nn.functional as F

class Mahiro:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt."
def patch(self, model):
m = model.clone()
def mahiro_normd(args):
scale: float = args['cond_scale']
cond_p: torch.Tensor = args['cond_denoised']
uncond_p: torch.Tensor = args['uncond_denoised']
#naive leap
leap = cond_p * scale
#sim with uncond leap
u_leap = uncond_p * scale
cfg = args["denoised"]
merge = (leap + cfg) / 2
normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
normm = torch.sqrt(merge.abs()) * merge.sign()
sim = F.cosine_similarity(normu, normm).mean()
simsc = 2 * (sim+1)
wm = (simsc*cfg + (4-simsc)*leap) / 4
return wm
m.set_model_sampler_post_cfg_function(mahiro_normd)
return (m, )

NODE_CLASS_MAPPINGS = {
"Mahiro": Mahiro
}

NODE_DISPLAY_NAME_MAPPINGS = {
"Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
}
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2147,6 +2147,7 @@ def init_builtin_extra_nodes():
"nodes_torch_compile.py",
"nodes_mochi.py",
"nodes_slg.py",
"nodes_mahiro.py",
"nodes_lt.py",
"nodes_hooks.py",
]
Expand Down

2 comments on commit 5bea1d2

@LukeG89
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yoinked-h @comfyanonymous I think you need to rename the node... 😅

search

node

@brahianrosswill
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this node is really good improve my generation quality by 15%

Please sign in to comment.