feat(ml): dual-UNet at encoder and middle block stages
beniz committed Sep 28, 2023
1 parent 8f86c72 commit 1ba07ee
Showing 1 changed file with 48 additions and 39 deletions.
87 changes: 48 additions & 39 deletions models/modules/unet_generator_attn/unet_generator_attn.py
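
At a glance, the commit wires a second (reference) UNet branch into the main one at the encoder and middle-block stages: each reference attention block exports its qkv projection, and the matching block in the main branch keeps its own queries but borrows the reference's keys and values. Below is a minimal shape check of that recombination with toy sizes; the tensor names follow the diff, but the values and dimensions are illustrative only, not repository code.

import torch
import torch.nn as nn

B, C, N = 2, 8, 16                      # batch, channels, flattened pixels (toy sizes)
x = torch.randn(B, C, N)
qkv = nn.Conv1d(C, 3 * C, 1)(x)         # (B, 3C, N), as produced by self.qkv below
q, k, v = qkv.chunk(3, dim=1)           # three (B, C, N) tensors

qkv_ref = torch.randn(B, 3 * C, N)      # same layout, exported by the reference branch
_, k_ref, v_ref = qkv_ref.chunk(3, dim=1)

# keep the image's queries, borrow the reference's keys and values
qkv_cross = torch.cat([q, k_ref, v_ref], dim=1)
assert qkv_cross.shape == (B, 3 * C, N)
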
@@ -1032,7 +1032,8 @@ def forward(self, x, emb, qkv_ref=None):
                 else:
                     cur_qkv_ref = qkv_ref.pop(0)
 
-                x, qkv = layer(x, qkv_ref=cur_qkv_ref)
+                x, qkv_out = layer(x, qkv_ref=cur_qkv_ref)
+                qkv.append(qkv_out)
             else:
                 x = layer(x)
         return x, qkv
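
The hunk above makes the sequential wrapper return one qkv per attention layer instead of only the last one. A hedged, runnable illustration with stand-in modules (AttnStub and run_stages are toys, not repository code):

import torch
import torch.nn as nn

class AttnStub(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.qkv = nn.Conv1d(channels, channels * 3, 1)

    def forward(self, x, qkv_ref=None):
        # the stub ignores qkv_ref; it only exports its own projections
        return x, self.qkv(x)

def run_stages(stages, x, qkv_ref=None):
    qkv = []
    for layer in stages:
        if isinstance(layer, AttnStub):
            cur_qkv_ref = None if qkv_ref is None else qkv_ref.pop(0)
            x, qkv_out = layer(x, qkv_ref=cur_qkv_ref)
            qkv.append(qkv_out)          # one entry per attention layer, in order
        else:
            x = layer(x)
    return x, qkv

stages = nn.ModuleList([AttnStub(8), nn.Conv1d(8, 8, 1), AttnStub(8)])
_, qkv = run_stages(stages, torch.randn(2, 8, 16))
assert len(qkv) == 2                     # two attention layers -> two qkv entries
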
@@ -1054,6 +1055,7 @@ def __init__(
         use_new_attention_order=False,
         use_transformer=False,
         use_ref=False,
+        terminal=False,
     ):
         super().__init__()
         self.channels = channels
@@ -1067,26 +1069,27 @@ def __init__(
         self.use_checkpoint = use_checkpoint
         self.use_transformer = use_transformer
         self.norm = normalization1d(channels)
+        self.use_ref = use_ref
+        self.terminal = terminal
 
-        if use_ref:
-            self.qkv = nn.Conv1d(channels, channels * 4, 1)
-        else:
-            self.qkv = nn.Conv1d(channels, channels * 3, 1)
+        #if self.use_ref:
+        #    self.qkv = nn.Conv1d(channels, channels * 4, 1)
+        #else:
+        self.qkv = nn.Conv1d(channels, channels * 3, 1)
 
-        self.qkv_ref = nn.Conv1d(channels, channels * 2, 1)
+        #self.qkv_ref = nn.Conv1d(channels, channels * 2, 1)
         if use_new_attention_order:
             # split qkv before split heads
             self.attention = QKVAttention(self.num_heads)
         else:
             # split heads before split qkv
             self.attention = QKVAttentionLegacy(self.num_heads)
 
-        if use_ref:
-            self.proj_out = zero_module(nn.Conv1d(channels * 2, channels, 1))
-        else:
-            self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
-
-        self.use_ref = use_ref
+        if not self.terminal:
+            if self.use_ref:
+                self.proj_out = zero_module(nn.Conv1d(channels * 2, channels, 1))
+            else:
+                self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
 
     def forward(self, x, qkv_ref=None):
         return checkpoint(
@@ -1102,26 +1105,30 @@ def _forward(self, x, qkv_ref):
 
         qkv = self.qkv(self.norm(x))
 
-        if self.use_ref:
-            assert qkv_ref is not None
+        if not self.terminal:
+            if self.use_ref:
+                assert qkv_ref is not None
 
-            q, k, v, q_ref = qkv.chunk(4, dim=1)
+                #q, k, v, q_ref = qkv.chunk(4, dim=1)
+                q, k, v = qkv.chunk(3, dim=1)
 
-            _, k_ref, v_ref = qkv_ref.chunk(3, dim=1)
+                _, k_ref, v_ref = qkv_ref.chunk(3, dim=1)
 
-            qkv = torch.cat([q, k, v], dim=1)
+                qkv = torch.cat([q, k, v], dim=1)
 
-            qkv_ref = torch.cat([q_ref, k_ref, v_ref], dim=1)
+                qkv_ref = torch.cat([q, k_ref, v_ref], dim=1)
 
-            h_ref = self.attention(qkv_ref)
+                h_ref = self.attention(qkv_ref)
 
-        h = self.attention(qkv)
+            h = self.attention(qkv)
 
-        if self.use_ref:
-            h = torch.cat([h, h_ref], dim=1)
+            if self.use_ref:
+                h = torch.cat([h, h_ref], dim=1)
 
-        h = self.proj_out(h)
-        return (x + h).reshape(b, c, *spatial), qkv
+            h = self.proj_out(h)
+            return (x + h).reshape(b, c, *spatial), qkv
+        else:
+            return None, qkv
 
 
 ##### unet ref
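
A hedged, self-contained sketch of the behaviour in the hunk above (TinyAttnRef and attend are toy, single-head stand-ins, not the repository's AttentionBlockRef / QKVAttention): a non-terminal block attends with its own queries against both its own and the reference's keys/values and projects the concatenation, while a terminal block skips attention entirely and only exports its qkv.

import torch
import torch.nn as nn

def attend(q, k, v):
    # plain single-head scaled dot-product attention over the flattened axis
    w = torch.softmax(q.transpose(1, 2) @ k / (q.shape[1] ** 0.5), dim=-1)
    return v @ w.transpose(1, 2)

class TinyAttnRef(nn.Module):
    def __init__(self, channels, terminal=False):
        super().__init__()
        self.terminal = terminal
        self.norm = nn.GroupNorm(1, channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        if not terminal:
            # concat of self-attention and reference attention -> 2C channels
            self.proj_out = nn.Conv1d(channels * 2, channels, 1)

    def forward(self, x, qkv_ref):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        if self.terminal:
            return None, qkv                      # only export the projections
        q, k, v = qkv.chunk(3, dim=1)
        _, k_ref, v_ref = qkv_ref.chunk(3, dim=1)
        h = torch.cat([attend(q, k, v),           # image attends to itself
                       attend(q, k_ref, v_ref)],  # and to the reference
                      dim=1)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial), qkv

ref_block = TinyAttnRef(8, terminal=True)
main_block = TinyAttnRef(8)
_, qkv_ref = ref_block(torch.randn(2, 8, 4, 4), qkv_ref=None)
y, _ = main_block(torch.randn(2, 8, 4, 4), qkv_ref=qkv_ref)
assert y.shape == (2, 8, 4, 4)
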
@@ -1329,14 +1336,14 @@ def __init__(
                         ch,
                         self.cond_embed_dim,
                         dropout,
-                        out_channel=int(mult * inner_channel),
+                        out_channel=int(mult * self.inner_channel),
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                         norm=norm,
                         efficient=efficient,
                     )
                 ]
-                ch = int(mult * inner_channel)
+                ch = int(mult * self.inner_channel)
                 if ds in attn_res:
                     layers.append(
                         AttentionBlockRef(
@@ -1368,6 +1375,8 @@
                         else Downsample(ch, conv_resample, out_channel=out_ch)
                     )
                 )
+                ch = out_ch
+                ds *= 2
 
         self.middle_block_ref = EmbedSequentialRef(
             ResBlock(
@@ -1385,16 +1394,17 @@
                 num_heads=num_heads,
                 num_head_channels=num_head_channels,
                 use_new_attention_order=use_new_attention_order,
+                terminal = True,
             ),
-            ResBlock(
-                ch,
-                self.cond_embed_dim,
-                dropout,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
-                norm=norm,
-                efficient=efficient,
-            ),
+            # ResBlock(
+            # ch,
+            # self.cond_embed_dim,
+            # dropout,
+            # use_checkpoint=use_checkpoint,
+            # use_scale_shift_norm=use_scale_shift_norm,
+            # norm=norm,
+            # efficient=efficient,
+            # ),
         )
 
         ###
@@ -1497,10 +1507,9 @@ def compute_feats(self, input, embed_gammas, ref=None):
         h = ref.type(torch.float32)
         for module in self.input_blocks_ref:
             h, qkv_ref = module(h, emb, qkv_ref=None)
-
             qkv_list.append(qkv_ref)
 
-        h, qkv_ref = self.middle_block_ref(h, emb, qkv_ref=None)
+        _, qkv_ref = self.middle_block_ref(h, emb, qkv_ref=None)
         qkv_list.append(qkv_ref)
 
         hs = []
@@ -1512,10 +1521,10 @@ def compute_feats(self, input, embed_gammas, ref=None):
 
         for module in self.input_blocks:
             h, _ = module(h, emb, qkv_ref=qkv_list.pop(0))
-
             hs.append(h)
 
-        h, _ = self.middle_block(h, emb, qkv_ref=qkv_list.pop(0))
+        qkv_ref = qkv_list.pop(0)
+        h, _ = self.middle_block(h, emb, qkv_ref=qkv_ref)#qkv_list.pop(0))
 
         outs, feats = h, hs
         return outs, feats, emb
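
A hedged sketch of the two-pass flow in compute_feats above, with toy stages standing in for input_blocks_ref / input_blocks and the middle blocks (ToyStage is illustrative only): the reference pass fills qkv_list with one entry per encoder stage plus one for the middle block, and the main pass pops them in the same order.

import torch
import torch.nn as nn

class ToyStage(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.mix = nn.Conv1d(channels, channels, 1)

    def forward(self, x, qkv_ref=None):
        qkv = self.qkv(x)
        if qkv_ref is not None:
            # main pass only: fold in the reference values (illustrative)
            _, _, v_ref = qkv_ref.chunk(3, dim=1)
            x = x + self.mix(v_ref)
        return x, qkv

C = 8
input_blocks_ref = nn.ModuleList([ToyStage(C) for _ in range(3)])
middle_block_ref = ToyStage(C)
input_blocks = nn.ModuleList([ToyStage(C) for _ in range(3)])
middle_block = ToyStage(C)

ref = torch.randn(2, C, 16)      # reference image, spatially flattened
h_in = torch.randn(2, C, 16)     # image being processed

# pass 1: reference branch, no conditioning, collect one qkv per stage
qkv_list = []
h = ref
for module in input_blocks_ref:
    h, qkv_ref = module(h, qkv_ref=None)
    qkv_list.append(qkv_ref)
_, qkv_ref = middle_block_ref(h, qkv_ref=None)   # features discarded, qkv kept
qkv_list.append(qkv_ref)

# pass 2: main branch consumes the list in the same order
h = h_in
for module in input_blocks:
    h, _ = module(h, qkv_ref=qkv_list.pop(0))
h, _ = middle_block(h, qkv_ref=qkv_list.pop(0))
assert not qkv_list                               # every entry was consumed
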