Skip to content

Commit

Permalink
Add some latent operation nodes.
Browse files Browse the repository at this point in the history
This is a port of the ModelSamplerTonemapNoiseTest from the experiments
repo.

To replicate that node use LatentOperationTonemapReinhard and
LatentApplyOperationCFG together.
  • Loading branch information
comfyanonymous committed Oct 15, 2024
1 parent f584758 commit 0dbba9f
Showing 1 changed file with 79 additions and 0 deletions.
79 changes: 79 additions & 0 deletions comfy_extras/nodes_latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,90 @@ def op(self, samples, seed_behavior):

return (samples_out,)

class LatentApplyOperation:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"operation": ("LATENT_OPERATION",),
}}

RETURN_TYPES = ("LATENT",)
FUNCTION = "op"

CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True

def op(self, samples, operation):
samples_out = samples.copy()

s1 = samples["samples"]
samples_out["samples"] = operation(latent=s1)
return (samples_out,)

class LatentApplyOperationCFG:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"operation": ("LATENT_OPERATION",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True

def patch(self, model, operation):
m = model.clone()

def pre_cfg_function(args):
conds_out = args["conds_out"]
if len(conds_out) == 2:
conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
else:
conds_out[0] = operation(latent=conds_out[0])
return conds_out

m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m, )

class LatentOperationTonemapReinhard:
@classmethod
def INPUT_TYPES(s):
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}

RETURN_TYPES = ("LATENT_OPERATION",)
FUNCTION = "op"

CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True

def op(self, multiplier):
def tonemap_reinhard(latent, **kwargs):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude

mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)

top = (std * 5 + mean) * multiplier

#reinhard
latent_vector_magnitude *= (1.0 / top)
new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
new_magnitude *= top

return normalized_latent * new_magnitude
return (tonemap_reinhard,)

NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
"LatentApplyOperation": LatentApplyOperation,
"LatentApplyOperationCFG": LatentApplyOperationCFG,
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
}

0 comments on commit 0dbba9f

Please sign in to comment.