-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MaHiRo (improved/alternate CFG) (#5975)
* Add MaHiRo (improved CFG) long explanation of what it is is [here](https://huggingface.co/spaces/yoinked/blue-arxiv) (2024-1208.1) note: if the node name has encoding issues (utf 8/whatever), id suggest to replace the face at the end with `(>w<)` * add it to nodes.py, add description, and make it a post_cfg function * fix * revert the sampler_cfg_function thing * switch cfg to args["denoised"]
- Loading branch information
Showing
2 changed files
with
42 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
class Mahiro: | ||
@classmethod | ||
def INPUT_TYPES(s): | ||
return {"required": {"model": ("MODEL",), | ||
}} | ||
RETURN_TYPES = ("MODEL",) | ||
RETURN_NAMES = ("patched_model",) | ||
FUNCTION = "patch" | ||
CATEGORY = "_for_testing" | ||
DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt." | ||
def patch(self, model): | ||
m = model.clone() | ||
def mahiro_normd(args): | ||
scale: float = args['cond_scale'] | ||
cond_p: torch.Tensor = args['cond_denoised'] | ||
uncond_p: torch.Tensor = args['uncond_denoised'] | ||
#naive leap | ||
leap = cond_p * scale | ||
#sim with uncond leap | ||
u_leap = uncond_p * scale | ||
cfg = args["denoised"] | ||
merge = (leap + cfg) / 2 | ||
normu = torch.sqrt(u_leap.abs()) * u_leap.sign() | ||
normm = torch.sqrt(merge.abs()) * merge.sign() | ||
sim = F.cosine_similarity(normu, normm).mean() | ||
simsc = 2 * (sim+1) | ||
wm = (simsc*cfg + (4-simsc)*leap) / 4 | ||
return wm | ||
m.set_model_sampler_post_cfg_function(mahiro_normd) | ||
return (m, ) | ||
|
||
NODE_CLASS_MAPPINGS = { | ||
"Mahiro": Mahiro | ||
} | ||
|
||
NODE_DISPLAY_NAME_MAPPINGS = { | ||
"Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5bea1d2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yoinked-h @comfyanonymous I think you need to rename the node... 😅
5bea1d2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this node is really good improve my generation quality by 15%