Skip to content

Commit

Permalink
chore: print number of dual U-Net blocks with cross reference
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Oct 5, 2023
1 parent 9075c7e commit 58d6be9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
13 changes: 7 additions & 6 deletions examples/example_ddpm_unetref_viton.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
"padding_type": "reflect",
"spectral": false,
"unet_mha_attn_res": [
16
4,
8
],
"unet_mha_channel_mults": [
1,
Expand All @@ -48,10 +49,10 @@
"unet_mha_num_head_channels": 32,
"unet_mha_num_heads": 1,
"unet_mha_res_blocks": [
2,
2,
2,
2
2,
4,
4,
2
],
"unet_mha_vit_efficient": false,
"uvit_num_transformer_blocks": 6
Expand Down Expand Up @@ -250,7 +251,7 @@
"cls_regression": false,
"compute_D_accuracy": false,
"compute_metrics_test": false,
"continue": true,
"continue": false,
"epoch": "latest",
"epoch_count": 1,
"export_jit": false,
Expand Down
8 changes: 6 additions & 2 deletions models/modules/palette_denoise_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from .image_bind.imagebind_model import ModalityType
import clip

from inspect import signature


class LabelEmbedder(nn.Module):
"""
Expand Down Expand Up @@ -34,6 +36,9 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses)
super().__init__()

self.model = model
model_sig = signature(model.forward)
self.model_nargs = len(model_sig.parameters)

self.conditioning = conditioning
self.cond_embed_dim = cond_embed_dim
self.ref_embed_net = ref_embed_net
Expand All @@ -48,7 +53,6 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses)
nn.init.normal_(self.netl_embedder_class.embedding_table.weight, std=0.02)

if "mask" in conditioning:
# TODO make a new option
cond_embed_mask = cond_embed_dim
self.netl_embedder_mask = LabelEmbedder(
nclasses,
Expand Down Expand Up @@ -101,7 +105,7 @@ def forward(self, input, embed_noise_level, cls, mask, ref):
if "mask" in self.conditioning:
input = torch.cat([input, mask_embed], dim=1)

if "Ref" in type(self.model).__name__:
if self.model_nargs == 4: # ref from dataloader with reference image
out = self.model(input, embedding, ref)
else:
out = self.model(input, embedding)
Expand Down
7 changes: 7 additions & 0 deletions models/modules/unet_generator_attn/unet_generator_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,8 @@ def __init__(
self.num_heads_upsample = num_heads_upsample
self.freq_space = freq_space

num_ref_blocks = 0

if self.freq_space:
from ..freq_utils import InverseHaarTransform, HaarTransform

Expand Down Expand Up @@ -1246,6 +1248,7 @@ def __init__(
]
ch = int(mult * self.inner_channel)
if ds in attn_res:
num_ref_blocks += 1
layers.append(
AttentionBlockRef(
ch,
Expand Down Expand Up @@ -1289,6 +1292,7 @@ def __init__(
ds *= 2
self._feature_size += ch

num_ref_blocks += 1
self.middle_block = EmbedSequentialRef(
ResBlock(
ch,
Expand Down Expand Up @@ -1490,6 +1494,7 @@ def __init__(
]
ch = int(self.inner_channel * mult)
if ds in attn_res:
num_ref_blocks += 1
layers.append(
AttentionBlockRef(
ch,
Expand Down Expand Up @@ -1555,6 +1560,8 @@ def __init__(
},
}

print("Dual U-Net: number of ref blocks: ", num_ref_blocks)

def compute_feats(self, input, embed_gammas, ref=None):
if embed_gammas is None:
# Only for GAN
Expand Down

0 comments on commit 58d6be9

Please sign in to comment.