diff --git a/models/diffusion_networks.py b/models/diffusion_networks.py
index bea5d8b4b..23cfc7872 100644
--- a/models/diffusion_networks.py
+++ b/models/diffusion_networks.py
@@ -36,6 +36,7 @@ def define_G(
     alg_palette_sampling_method,
     alg_palette_conditioning,
     alg_palette_cond_embed_dim,
+    alg_palette_ref_embed_net,
     model_prior_321_backwardcompatibility,
     dropout=0,
     channel_mults=(1, 2, 4, 8),
@@ -154,6 +155,7 @@ def define_G(
     denoise_fn = PaletteDenoiseFn(
         model=model,
         cond_embed_dim=cond_embed_dim,
+        ref_embed_net=alg_palette_ref_embed_net,
         conditioning=alg_palette_conditioning,
         nclasses=f_s_semantic_nclasses,
     )
diff --git a/models/modules/palette_denoise_fn.py b/models/modules/palette_denoise_fn.py
index d64c3819f..b2b8b3d66 100644
--- a/models/modules/palette_denoise_fn.py
+++ b/models/modules/palette_denoise_fn.py
@@ -3,9 +3,9 @@
 from torchvision import transforms
 from einops import rearrange
 
-
 from .image_bind import imagebind_model
 from .image_bind.imagebind_model import ModalityType
+import clip
 
 
 class LabelEmbedder(nn.Module):
@@ -30,12 +30,13 @@ def forward(self, labels):
 
 
 class PaletteDenoiseFn(nn.Module):
-    def __init__(self, model, cond_embed_dim, conditioning, nclasses):
+    def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses):
         super().__init__()
 
         self.model = model
         self.conditioning = conditioning
         self.cond_embed_dim = cond_embed_dim
+        self.ref_embed_net = ref_embed_net
 
         # Label embedding
         if "class" in conditioning:
@@ -59,9 +60,6 @@ def __init__(self, model, cond_embed_dim, conditioning, nclasses):
         if "ref" in conditioning:
             cond_embed_class = cond_embed_dim // 2
 
-            """self.freezenetImageBin = imagebind_model.imagebind_huge(pretrained=True)
-            self.freezenetImageBin.eval()"""
-
             self.ref_transform = transforms.Compose(
                 [
                     transforms.Resize(
@@ -71,15 +69,22 @@ def __init__(self, model, cond_embed_dim, conditioning, nclasses):
                 ]
             )
 
-            import clip
+            if ref_embed_net == "clip":
+                model_name = "ViT-B/16"
+                self.freezenetClip, _ = clip.load(model_name)
+                self.freezenetClip = self.freezenetClip.visual.float()
+                ref_embed_dim = 512
 
-            model_name = "ViT-B/16"
-            self.freezenetClip, _ = clip.load(model_name)
+            elif ref_embed_net == "imagebind":
+                self.freezenetImageBin = imagebind_model.imagebind_huge(pretrained=True)
+                self.freezenetImageBin.eval()
+                ref_embed_dim = 1024
 
-            self.freezenetClip = self.freezenetClip.visual.float()
+            else:
+                raise NotImplementedError(ref_embed_net)
 
             self.emb_layers = nn.Sequential(
-                torch.nn.SiLU(), nn.Linear(512, cond_embed_class)
+                torch.nn.SiLU(), nn.Linear(ref_embed_dim, cond_embed_class)
             )
 
     def forward(self, input, embed_noise_level, cls, mask, ref):
@@ -115,12 +120,18 @@ def compute_cond(self, input, cls, mask, ref):
             mask_embed = None
 
         if "ref" in self.conditioning:
-            """ref = self.bin_transform(ref)
-            input_ref = {ModalityType.VISION: ref}
-            ref_embed = self.freezenetImageBin(input_ref)["vision"]"""
+            ref = self.ref_transform(ref)
+
+            if self.ref_embed_net == "clip":
+                ref_embed = self.freezenetClip(ref)
+
+            elif self.ref_embed_net == "imagebind":
+                input_ref = {ModalityType.VISION: ref}
+                ref_embed = self.freezenetImageBin(input_ref)["vision"]
+
+            else:
+                raise NotImplementedError(self.ref_embed_net)
 
-            ref_input = self.ref_transform(ref)
-            ref_embed = self.freezenetClip(ref_input)
             ref_embed = self.emb_layers(ref_embed)
 
         else:
diff --git a/models/palette_model.py b/models/palette_model.py
index 2fd0f3ae1..dff765119 100644
--- a/models/palette_model.py
+++ b/models/palette_model.py
@@ -228,6 +228,14 @@ def modify_commandline_options(parser, is_train=True):
             help="whether to generate samples of each images",
         )
 
+        parser.add_argument(
+            "--alg_palette_ref_embed_net",
+            type=str,
+            default="clip",
+            choices=["clip", "imagebind"],
+            help="embedding network to use for ref conditioning",
+        )
+
        return parser
 
     def __init__(self, opt, rank):
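
Note (illustrative sketch, not part of the patch): the two reference encoders expose different feature widths, which is why emb_layers now projects from ref_embed_dim rather than the hard-coded 512. The dispatch amounts to the following, written here as a hypothetical standalone helper:

    # Hypothetical helper, not in the repository; it mirrors the dimension
    # choice made in PaletteDenoiseFn.__init__ above.
    def ref_embed_dim_for(ref_embed_net: str) -> int:
        if ref_embed_net == "clip":
            return 512   # CLIP ViT-B/16 visual tower output width
        elif ref_embed_net == "imagebind":
            return 1024  # imagebind_huge vision trunk output width
        raise NotImplementedError(ref_embed_net)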