Skip to content

Commit

Permalink
feat: ref cross attention unet
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau committed Sep 13, 2023
1 parent e2893e8 commit 72df92a
Show file tree
Hide file tree
Showing 5 changed files with 588 additions and 37 deletions.
30 changes: 29 additions & 1 deletion models/diffusion_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

from .modules.diffusion_generator import DiffusionGenerator
from .modules.resnet_architecture.resnet_generator_diff import ResnetGenerator_attn_diff
from .modules.unet_generator_attn.unet_generator_attn import UNet, UViT
from .modules.unet_generator_attn.unet_generator_attn import (
UNet,
UViT,
UNetGeneratorRefAttn,
)
from .modules.palette_denoise_fn import PaletteDenoiseFn


Expand Down Expand Up @@ -107,6 +111,30 @@ def define_G(
freq_space=train_feat_wavelet,
)

elif G_netG == "unet_mha_ref_attn":
cond_embed_dim = alg_palette_cond_embed_dim

model = UNetGeneratorRefAttn(
image_size=data_crop_size,
in_channel=in_channel,
inner_channel=G_ngf,
out_channel=model_output_nc,
res_blocks=G_unet_mha_res_blocks,
attn_res=G_unet_mha_attn_res,
num_heads=G_unet_mha_num_heads,
num_head_channels=G_unet_mha_num_head_channels,
tanh=False,
dropout=G_dropout,
n_timestep_train=G_diff_n_timestep_train,
n_timestep_test=G_diff_n_timestep_test,
channel_mults=G_unet_mha_channel_mults,
norm=G_unet_mha_norm_layer,
group_norm_size=G_unet_mha_group_norm_size,
efficient=G_unet_mha_vit_efficient,
cond_embed_dim=cond_embed_dim,
freq_space=train_feat_wavelet,
)

elif G_netG == "uvit":
model = UViT(
image_size=data_crop_size,
Expand Down
7 changes: 6 additions & 1 deletion models/modules/palette_denoise_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,12 @@ def forward(self, input, embed_noise_level, cls, mask, ref):
if "mask" in self.conditioning:
input = torch.cat([input, mask_embed], dim=1)

return self.model(input, embedding)
if "Ref" in type(self.model).__name__:
out = self.model(input, embedding, ref)
else:
out = self.model(input, embedding)

return out

def compute_cond(self, input, cls, mask, ref):
if "class" in self.conditioning and cls is not None:
Expand Down
Loading

0 comments on commit 72df92a

Please sign in to comment.