Skip to content

Commit

Permalink
feat: option to select embedding network for ref conditioning
Browse files Browse the repository at this point in the history
  • Loading branch information
royale committed Aug 25, 2023
1 parent 720c224 commit 8375c2b
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 15 deletions.
2 changes: 2 additions & 0 deletions models/diffusion_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def define_G(
alg_palette_sampling_method,
alg_palette_conditioning,
alg_palette_cond_embed_dim,
alg_palette_ref_embed_net,
model_prior_321_backwardcompatibility,
dropout=0,
channel_mults=(1, 2, 4, 8),
Expand Down Expand Up @@ -154,6 +155,7 @@ def define_G(
denoise_fn = PaletteDenoiseFn(
model=model,
cond_embed_dim=cond_embed_dim,
ref_embed_net=alg_palette_ref_embed_net,
conditioning=alg_palette_conditioning,
nclasses=f_s_semantic_nclasses,
)
Expand Down
41 changes: 26 additions & 15 deletions models/modules/palette_denoise_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from torchvision import transforms
from einops import rearrange


from .image_bind import imagebind_model
from .image_bind.imagebind_model import ModalityType
import clip


class LabelEmbedder(nn.Module):
Expand All @@ -30,12 +30,13 @@ def forward(self, labels):


class PaletteDenoiseFn(nn.Module):
def __init__(self, model, cond_embed_dim, conditioning, nclasses):
def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses):
super().__init__()

self.model = model
self.conditioning = conditioning
self.cond_embed_dim = cond_embed_dim
self.ref_embed_net = ref_embed_net

# Label embedding
if "class" in conditioning:
Expand All @@ -59,9 +60,6 @@ def __init__(self, model, cond_embed_dim, conditioning, nclasses):
if "ref" in conditioning:
cond_embed_class = cond_embed_dim // 2

"""self.freezenetImageBin = imagebind_model.imagebind_huge(pretrained=True)
self.freezenetImageBin.eval()"""

self.ref_transform = transforms.Compose(
[
transforms.Resize(
Expand All @@ -71,15 +69,22 @@ def __init__(self, model, cond_embed_dim, conditioning, nclasses):
]
)

import clip
if ref_embed_net == "clip":
model_name = "ViT-B/16"
self.freezenetClip, _ = clip.load(model_name)
self.freezenetClip = self.freezenetClip.visual.float()
ref_embed_dim = 512

model_name = "ViT-B/16"
self.freezenetClip, _ = clip.load(model_name)
elif ref_embed_net == "imagebind":
self.freezenetImageBin = imagebind_model.imagebind_huge(pretrained=True)
self.freezenetImageBin.eval()
ref_embed_dim = 1024

self.freezenetClip = self.freezenetClip.visual.float()
else:
raise NotImplementedError(ref_embed_net)

self.emb_layers = nn.Sequential(
torch.nn.SiLU(), nn.Linear(512, cond_embed_class)
torch.nn.SiLU(), nn.Linear(ref_embed_dim, cond_embed_class)
)

def forward(self, input, embed_noise_level, cls, mask, ref):
Expand Down Expand Up @@ -115,12 +120,18 @@ def compute_cond(self, input, cls, mask, ref):
mask_embed = None

if "ref" in self.conditioning:
"""ref = self.bin_transform(ref)
input_ref = {ModalityType.VISION: ref}
ref_embed = self.freezenetImageBin(input_ref)["vision"]"""
ref = self.ref_transform(ref)

if self.ref_embed_net == "clip":
ref_embed = self.freezenetClip(ref)

elif self.ref_embed_net == "imagebind":
input_ref = {ModalityType.VISION: ref}
ref_embed = self.freezenetImageBin(input_ref)["vision"]

else:
raise NotImplementedError(ref_embed_net)

ref_input = self.ref_transform(ref)
ref_embed = self.freezenetClip(ref_input)
ref_embed = self.emb_layers(ref_embed)

else:
Expand Down
8 changes: 8 additions & 0 deletions models/palette_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,14 @@ def modify_commandline_options(parser, is_train=True):
help="whether to generate samples of each images",
)

parser.add_argument(
"--alg_palette_ref_embed_net",
type=str,
default="clip",
choices=["clip", "imagebind"],
help="embedding network to use for ref conditioning",
)

return parser

def __init__(self, opt, rank):
Expand Down

0 comments on commit 8375c2b

Please sign in to comment.