diff --git a/models/modules/unet_generator_attn/unet_generator_attn.py b/models/modules/unet_generator_attn/unet_generator_attn.py
index 535157d30..4c56b86df 100644
--- a/models/modules/unet_generator_attn/unet_generator_attn.py
+++ b/models/modules/unet_generator_attn/unet_generator_attn.py
@@ -1032,7 +1032,8 @@ def forward(self, x, emb, qkv_ref=None):
                 else:
                     cur_qkv_ref = qkv_ref.pop(0)
-                x, qkv = layer(x, qkv_ref=cur_qkv_ref)
+                x, qkv_out = layer(x, qkv_ref=cur_qkv_ref)
+                qkv.append(qkv_out)
             else:
                 x = layer(x)
         return x, qkv
 
@@ -1054,6 +1055,7 @@ def __init__(
         use_new_attention_order=False,
         use_transformer=False,
         use_ref=False,
+        terminal=False,
     ):
         super().__init__()
         self.channels = channels
@@ -1067,13 +1069,15 @@ def __init__(
         self.use_checkpoint = use_checkpoint
         self.use_transformer = use_transformer
         self.norm = normalization1d(channels)
+        self.use_ref = use_ref
+        self.terminal = terminal
 
-        if use_ref:
-            self.qkv = nn.Conv1d(channels, channels * 4, 1)
-        else:
-            self.qkv = nn.Conv1d(channels, channels * 3, 1)
+        #if self.use_ref:
+        #    self.qkv = nn.Conv1d(channels, channels * 4, 1)
+        #else:
+        self.qkv = nn.Conv1d(channels, channels * 3, 1)
 
-        self.qkv_ref = nn.Conv1d(channels, channels * 2, 1)
+        #self.qkv_ref = nn.Conv1d(channels, channels * 2, 1)
         if use_new_attention_order:
             # split qkv before split heads
             self.attention = QKVAttention(self.num_heads)
@@ -1081,12 +1085,11 @@ def __init__(
             # split heads before split qkv
             self.attention = QKVAttentionLegacy(self.num_heads)
 
-        if use_ref:
-            self.proj_out = zero_module(nn.Conv1d(channels * 2, channels, 1))
-        else:
-            self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
-
-        self.use_ref = use_ref
+        if not self.terminal:
+            if self.use_ref:
+                self.proj_out = zero_module(nn.Conv1d(channels * 2, channels, 1))
+            else:
+                self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
 
     def forward(self, x, qkv_ref=None):
         return checkpoint(
@@ -1102,26 +1105,30 @@ def _forward(self, x, qkv_ref):
 
         qkv = self.qkv(self.norm(x))
 
-        if self.use_ref:
-            assert qkv_ref is not None
+        if not self.terminal:
+            if self.use_ref:
+                assert qkv_ref is not None
 
-            q, k, v, q_ref = qkv.chunk(4, dim=1)
+                #q, k, v, q_ref = qkv.chunk(4, dim=1)
+                q, k, v = qkv.chunk(3, dim=1)
 
-            _, k_ref, v_ref = qkv_ref.chunk(3, dim=1)
+                _, k_ref, v_ref = qkv_ref.chunk(3, dim=1)
 
-            qkv = torch.cat([q, k, v], dim=1)
+                qkv = torch.cat([q, k, v], dim=1)
 
-            qkv_ref = torch.cat([q_ref, k_ref, v_ref], dim=1)
+                qkv_ref = torch.cat([q, k_ref, v_ref], dim=1)
 
-            h_ref = self.attention(qkv_ref)
+                h_ref = self.attention(qkv_ref)
 
-        h = self.attention(qkv)
+            h = self.attention(qkv)
 
-        if self.use_ref:
-            h = torch.cat([h, h_ref], dim=1)
+            if self.use_ref:
+                h = torch.cat([h, h_ref], dim=1)
 
-        h = self.proj_out(h)
-        return (x + h).reshape(b, c, *spatial), qkv
+            h = self.proj_out(h)
+            return (x + h).reshape(b, c, *spatial), qkv
+        else:
+            return None, qkv
 
 
 ##### unet ref
@@ -1329,14 +1336,14 @@ def __init__(
                         ch,
                         self.cond_embed_dim,
                         dropout,
-                        out_channel=int(mult * inner_channel),
+                        out_channel=int(mult * self.inner_channel),
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                         norm=norm,
                         efficient=efficient,
                     )
                 ]
-                ch = int(mult * inner_channel)
+                ch = int(mult * self.inner_channel)
                 if ds in attn_res:
                     layers.append(
                         AttentionBlockRef(
@@ -1368,6 +1375,8 @@ def __init__(
                         else Downsample(ch, conv_resample, out_channel=out_ch)
                     )
                 )
+                ch = out_ch
+                ds *= 2
 
         self.middle_block_ref = EmbedSequentialRef(
             ResBlock(
@@ -1385,16 +1394,17 @@ def __init__(
                 num_heads=num_heads,
                 num_head_channels=num_head_channels,
                 use_new_attention_order=use_new_attention_order,
+                terminal=True,
             ),
-            ResBlock(
-                ch,
-                self.cond_embed_dim,
-                dropout,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-                norm=norm,
-                efficient=efficient,
-            ),
+#            ResBlock(
+#                ch,
+#                self.cond_embed_dim,
+#                dropout,
+#                use_checkpoint=use_checkpoint,
+#                use_scale_shift_norm=use_scale_shift_norm,
+#                norm=norm,
+#                efficient=efficient,
+#            ),
         )
 
         ###
@@ -1497,10 +1507,9 @@ def compute_feats(self, input, embed_gammas, ref=None):
 
         h = ref.type(torch.float32)
         for module in self.input_blocks_ref:
             h, qkv_ref = module(h, emb, qkv_ref=None)
-            qkv_list.append(qkv_ref)
 
-        h, qkv_ref = self.middle_block_ref(h, emb, qkv_ref=None)
+        _, qkv_ref = self.middle_block_ref(h, emb, qkv_ref=None)
         qkv_list.append(qkv_ref)
 
         hs = []
@@ -1512,10 +1521,10 @@ def compute_feats(self, input, embed_gammas, ref=None):
 
         for module in self.input_blocks:
             h, _ = module(h, emb, qkv_ref=qkv_list.pop(0))
-            hs.append(h)
 
-        h, _ = self.middle_block(h, emb, qkv_ref=qkv_list.pop(0))
+        qkv_ref = qkv_list.pop(0)
+        h, _ = self.middle_block(h, emb, qkv_ref=qkv_ref)  # qkv_list.pop(0))
 
         outs, feats = h, hs
         return outs, feats, emb
 
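For review, here is a minimal, self-contained sketch of the attention flow the patch converges on, assuming single-head attention, no gradient checkpointing, and `nn.GroupNorm(1, C)` as a stand-in for `normalization1d`. `TinyRefAttention` and `simple_qkv_attention` are illustrative names, not the repository's `AttentionBlockRef` / `QKVAttention`: a `terminal` block only exports its qkv, and a `use_ref` block pairs its own query `q` with the reference branch's `k_ref` / `v_ref` (the learned `q_ref` is gone), then fuses both attention outputs through a `channels * 2 -> channels` projection.

```python
import torch
import torch.nn as nn


def simple_qkv_attention(qkv):
    # Single-head stand-in for QKVAttention: qkv has shape (B, 3*C, T).
    q, k, v = qkv.chunk(3, dim=1)
    scale = q.shape[1] ** -0.5
    weight = torch.softmax(torch.einsum("bct,bcs->bts", q * scale, k), dim=-1)
    return torch.einsum("bts,bcs->bct", weight, v)


class TinyRefAttention(nn.Module):
    """Illustrative stand-in for AttentionBlockRef after this patch.

    terminal=True : the block only exports its qkv tensor (middle_block_ref case).
    use_ref=True  : queries of the current branch attend over the reference
                    branch's keys/values; both attention outputs are fused.
    """

    def __init__(self, channels, use_ref=False, terminal=False):
        super().__init__()
        self.use_ref = use_ref
        self.terminal = terminal
        self.norm = nn.GroupNorm(1, channels)  # stand-in for normalization1d
        self.qkv = nn.Conv1d(channels, channels * 3, 1)  # always 3 * C now
        if not terminal:
            proj_in = channels * 2 if use_ref else channels
            self.proj_out = nn.Conv1d(proj_in, channels, 1)

    def forward(self, x, qkv_ref=None):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))

        if self.terminal:
            # Terminal block: no projection, just hand qkv to the other branch.
            return None, qkv

        if self.use_ref:
            assert qkv_ref is not None
            q, _, _ = qkv.chunk(3, dim=1)
            _, k_ref, v_ref = qkv_ref.chunk(3, dim=1)
            # Own query, reference keys/values (the patch drops the learned q_ref).
            h_ref = simple_qkv_attention(torch.cat([q, k_ref, v_ref], dim=1))

        h = simple_qkv_attention(qkv)
        if self.use_ref:
            h = torch.cat([h, h_ref], dim=1)  # fuse self- and ref-attention
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial), qkv


if __name__ == "__main__":
    ref_block = TinyRefAttention(channels=8, terminal=True)
    main_block = TinyRefAttention(channels=8, use_ref=True)

    ref = torch.randn(2, 8, 16, 16)
    x = torch.randn(2, 8, 16, 16)

    _, qkv_ref = ref_block(ref)  # reference pass: export qkv only
    y, _ = main_block(x, qkv_ref=qkv_ref)  # main pass: cross-attend to the reference
    print(y.shape)  # torch.Size([2, 8, 16, 16])
```

Reusing the current branch's `q` against the reference keys/values is what lets `self.qkv` shrink back to `channels * 3` output maps, and the `terminal` flag spares `middle_block_ref` a `proj_out` it would never use.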