Skip to content

Commit

Permalink
Add LatentInterpolate to interpolate between latents.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 20, 2023
1 parent dba4f3b commit 31c5ea7
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions comfy_extras/nodes_latent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import comfy.utils
import torch

def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]:
Expand Down Expand Up @@ -67,8 +68,43 @@ def op(self, samples, multiplier):
samples_out["samples"] = s1 * multiplier
return (samples_out,)

class LatentInterpolate:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}

RETURN_TYPES = ("LATENT",)
FUNCTION = "op"

CATEGORY = "latent/advanced"

def op(self, samples1, samples2, ratio):
samples_out = samples1.copy()

s1 = samples1["samples"]
s2 = samples2["samples"]

s2 = reshape_latent_to(s1.shape, s2)

m1 = torch.linalg.vector_norm(s1, dim=(1))
m2 = torch.linalg.vector_norm(s2, dim=(1))

s1 = torch.nan_to_num(s1 / m1)
s2 = torch.nan_to_num(s2 / m2)

t = (s1 * ratio + s2 * (1.0 - ratio))
mt = torch.linalg.vector_norm(t, dim=(1))
st = torch.nan_to_num(t / mt)

samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)

NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
}

0 comments on commit 31c5ea7

Please sign in to comment.