diff --git a/examples/example_ddpm_unetref_viton.json b/examples/example_ddpm_unetref_viton.json index ed029bfd6..190390d03 100644 --- a/examples/example_ddpm_unetref_viton.json +++ b/examples/example_ddpm_unetref_viton.json @@ -35,7 +35,8 @@ "padding_type": "reflect", "spectral": false, "unet_mha_attn_res": [ - 16 + 4, + 8 ], "unet_mha_channel_mults": [ 1, @@ -48,10 +49,10 @@ "unet_mha_num_head_channels": 32, "unet_mha_num_heads": 1, "unet_mha_res_blocks": [ - 2, - 2, - 2, - 2 + 2, + 4, + 4, + 2 ], "unet_mha_vit_efficient": false, "uvit_num_transformer_blocks": 6 @@ -250,7 +251,7 @@ "cls_regression": false, "compute_D_accuracy": false, "compute_metrics_test": false, - "continue": true, + "continue": false, "epoch": "latest", "epoch_count": 1, "export_jit": false, diff --git a/models/modules/palette_denoise_fn.py b/models/modules/palette_denoise_fn.py index 516c78124..61dedec3b 100644 --- a/models/modules/palette_denoise_fn.py +++ b/models/modules/palette_denoise_fn.py @@ -7,6 +7,8 @@ from .image_bind.imagebind_model import ModalityType import clip +from inspect import signature + class LabelEmbedder(nn.Module): """ @@ -34,6 +36,9 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses) super().__init__() self.model = model + model_sig = signature(model.forward) + self.model_nargs = len(model_sig.parameters) + self.conditioning = conditioning self.cond_embed_dim = cond_embed_dim self.ref_embed_net = ref_embed_net @@ -48,7 +53,6 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses) nn.init.normal_(self.netl_embedder_class.embedding_table.weight, std=0.02) if "mask" in conditioning: - # TODO make a new option cond_embed_mask = cond_embed_dim self.netl_embedder_mask = LabelEmbedder( nclasses, @@ -101,7 +105,7 @@ def forward(self, input, embed_noise_level, cls, mask, ref): if "mask" in self.conditioning: input = torch.cat([input, mask_embed], dim=1) - if "Ref" in type(self.model).__name__: + if self.model_nargs == 4: # ref from dataloader with reference image out = self.model(input, embedding, ref) else: out = self.model(input, embedding) diff --git a/models/modules/unet_generator_attn/unet_generator_attn.py b/models/modules/unet_generator_attn/unet_generator_attn.py index cf3611477..80aa942e2 100644 --- a/models/modules/unet_generator_attn/unet_generator_attn.py +++ b/models/modules/unet_generator_attn/unet_generator_attn.py @@ -1209,6 +1209,8 @@ def __init__( self.num_heads_upsample = num_heads_upsample self.freq_space = freq_space + num_ref_blocks = 0 + if self.freq_space: from ..freq_utils import InverseHaarTransform, HaarTransform @@ -1246,6 +1248,7 @@ def __init__( ] ch = int(mult * self.inner_channel) if ds in attn_res: + num_ref_blocks += 1 layers.append( AttentionBlockRef( ch, @@ -1289,6 +1292,7 @@ def __init__( ds *= 2 self._feature_size += ch + num_ref_blocks += 1 self.middle_block = EmbedSequentialRef( ResBlock( ch, @@ -1490,6 +1494,7 @@ def __init__( ] ch = int(self.inner_channel * mult) if ds in attn_res: + num_ref_blocks += 1 layers.append( AttentionBlockRef( ch, @@ -1555,6 +1560,8 @@ def __init__( }, } + print("Dual U-Net: number of ref blocks: ", num_ref_blocks) + def compute_feats(self, input, embed_gammas, ref=None): if embed_gammas is None: # Only for GAN