feat: ref cross attention unet
pnsuau committed Aug 25, 2023
1 parent 4f0500d commit b479fa8
Showing 3 changed files with 134 additions and 34 deletions.
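In short: the commit threads a reference image (`ref`) from the denoising wrapper into the UNet, adds a mirrored encoder branch (`input_blocks_ref` and `middle_block_ref`) that encodes it, caches the qkv projections produced by that branch's attention blocks, and lets the attention blocks of the main denoising pass use those cached keys and values as cross-attention context (queries come from the denoised features, keys and values from the reference).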
4 changes: 0 additions & 4 deletions models/modules/diffusion_generator.py
@@ -41,7 +41,6 @@ def __init__(

if loading_backward_compatibility:
if type(self.denoise_fn.model).__name__ == "ResnetGenerator_attn_diff":

inner_channel = G_ngf
self.cond_embed = nn.Sequential(
nn.Linear(inner_channel, cond_embed_dim),
@@ -50,7 +49,6 @@
)

elif type(self.denoise_fn.model).__name__ == "UNet":

inner_channel = G_ngf
cond_embed_dim = inner_channel * 4

@@ -248,7 +246,6 @@ def p_sample(
y_cond=None,
guidance_scale=0.0,
):

model_mean, model_log_variance = self.p_mean_variance(
y_t=y_t,
t=t,
@@ -423,7 +420,6 @@ def ddim_p_mean_variance(
return model_mean, posterior_log_variance

def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):

b, *_ = y_0.shape
t = torch.randint(
1, self.denoise_fn.model.num_timesteps_train, (b,), device=y_0.device
2 changes: 1 addition & 1 deletion models/modules/palette_denoise_fn.py
@@ -96,7 +96,7 @@ def forward(self, input, embed_noise_level, cls, mask, ref):
if "mask" in self.conditioning:
input = torch.cat([input, mask_embed], dim=1)

return self.model(input, embedding)
return self.model(input, embedding, ref=ref)

def compute_cond(self, input, cls, mask, ref):
if "class" in self.conditioning and cls is not None:
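The functional change in this file is a single line: the denoising wrapper now forwards the reference tensor to the wrapped model. A minimal sketch of the wrapper's forward after the change; the class name `DenoiseFnSketch`, its constructor, and the simplified mask handling are assumptions based on the surrounding context, not the repository's code:

```python
import torch
import torch.nn as nn


class DenoiseFnSketch(nn.Module):
    """Simplified stand-in for the denoising wrapper; only `ref=ref` is the point."""

    def __init__(self, model, conditioning=("mask",)):
        super().__init__()
        self.model = model          # a UNet accepting forward(input, embedding, ref=...)
        self.conditioning = conditioning

    def forward(self, input, embedding, mask_embed, ref):
        if "mask" in self.conditioning:
            # conditioning is concatenated along channels, as in the diff context above
            input = torch.cat([input, mask_embed], dim=1)
        # the reference image is passed through untouched; the UNet turns it into
        # cross-attention keys/values (see unet_generator_attn.py below)
        return self.model(input, embedding, ref=ref)
```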
162 changes: 133 additions & 29 deletions models/modules/unet_generator_attn/unet_generator_attn.py
@@ -41,13 +41,21 @@ class EmbedSequential(nn.Sequential, EmbedBlock):
support it as an extra input.
"""

def forward(self, x, emb):
def forward(self, x, emb, qkv_ref):
qkv = []
for layer in self:
if isinstance(layer, EmbedBlock):
x = layer(x, emb)
elif isinstance(layer, AttentionBlock):
if qkv_ref is None or type(qkv_ref) != list:
cur_qkv_ref = qkv_ref
else:
cur_qkv_ref = qkv_ref.pop(0)

x, qkv = layer(x, qkv_ref=cur_qkv_ref)
else:
x = layer(x)
return x
return x, qkv
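The rewritten `EmbedSequential.forward` threads an optional `qkv_ref` through the sequence: embedding-aware blocks still receive `emb`, attention blocks additionally receive one cached reference qkv (popped in order when a list is given, otherwise used as-is) and report the qkv they computed. A self-contained sketch of that dispatch contract, using toy stand-ins rather than the repository's blocks:

```python
import torch
import torch.nn as nn


class ToyEmbedBlock(nn.Module):
    """Stand-in for EmbedBlock: consumes the conditioning embedding."""

    def forward(self, x, emb):
        return x + emb.mean()


class ToyAttention(nn.Module):
    """Stand-in for AttentionBlock: returns (output, qkv) like the modified block."""

    def forward(self, x, qkv_ref=None):
        qkv = torch.cat([x, x, x], dim=1)          # placeholder qkv projection
        if qkv_ref is not None:
            x = x + qkv_ref[:, : x.shape[1]]       # pretend to attend to the reference
        return x, qkv


def run_embed_sequential(layers, x, emb, qkv_ref):
    """Mirrors the dispatch logic of the new EmbedSequential.forward."""
    qkv = []
    for layer in layers:
        if isinstance(layer, ToyEmbedBlock):
            x = layer(x, emb)
        elif isinstance(layer, ToyAttention):
            cur = qkv_ref.pop(0) if isinstance(qkv_ref, list) else qkv_ref
            x, qkv = layer(x, qkv_ref=cur)
        else:
            x = layer(x)
    return x, qkv


x = torch.randn(2, 8, 16)
emb = torch.randn(2, 4)
layers = [ToyEmbedBlock(), ToyAttention()]
_, ref_qkv = run_embed_sequential(layers, x, emb, qkv_ref=None)    # record reference qkv
out, _ = run_embed_sequential(layers, x, emb, qkv_ref=[ref_qkv])   # replay it on a second pass
```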


class Upsample(nn.Module):
@@ -296,6 +304,7 @@ def __init__(
self.use_transformer = use_transformer
self.norm = normalization1d(channels)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.qkv_ref = nn.Conv1d(channels, channels * 3, 1)
if use_new_attention_order:
# split qkv before split heads
self.attention = QKVAttention(self.num_heads)
@@ -305,24 +314,36 @@

self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))

def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
def forward(self, x, qkv_ref=None):
return checkpoint(
self._forward, (x, qkv_ref), self.parameters(), self.use_checkpoint
)

def _forward(self, x):
def _forward(self, x, qkv_ref):
b, c, *spatial = x.shape
if self.use_transformer:
x = x.reshape(b, -1, c)
else:
x = x.reshape(b, c, -1)

qkv = self.qkv(self.norm(x))

if qkv_ref is not None:
q, _, _ = qkv.chunk(3, dim=1)

_, k, v = qkv_ref.chunk(3, dim=1)

qkv = torch.cat([q, k, v], dim=1)

h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
return (x + h).reshape(b, c, *spatial), qkv
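Inside `AttentionBlock._forward`, reference conditioning is a key/value swap: the query is kept from the current feature map's qkv projection, while the key and value are taken from the reference qkv before attention is applied. A hedged, self-contained sketch of that swap; head count and shapes are assumptions, and the attention math is a plain scaled dot product rather than the repository's QKVAttention classes:

```python
import torch


def cross_attend(qkv, qkv_ref, num_heads=4):
    """Attend with q from the current pass and k, v from the reference pass."""
    b, width, n = qkv.shape                     # (batch, 3 * channels, tokens)
    ch = width // (3 * num_heads)
    q, _, _ = qkv.chunk(3, dim=1)               # keep the query of the current pass
    _, k, v = qkv_ref.chunk(3, dim=1)           # borrow key/value from the reference
    q = q.reshape(b * num_heads, ch, n)
    k = k.reshape(b * num_heads, ch, -1)
    v = v.reshape(b * num_heads, ch, -1)
    attn = torch.softmax((q.transpose(1, 2) @ k) * ch ** -0.5, dim=-1)   # (b*h, n, n_ref)
    out = (attn @ v.transpose(1, 2)).transpose(1, 2)                     # (b*h, ch, n)
    return out.reshape(b, num_heads * ch, n)


x_qkv = torch.randn(2, 3 * 64, 16 * 16)     # qkv of the denoised features
ref_qkv = torch.randn(2, 3 * 64, 16 * 16)   # qkv recorded from the reference pass
out = cross_attend(x_qkv, ref_qkv)          # -> (2, 64, 256)
```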


class QKVAttentionLegacy(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""

def __init__(self, n_heads):
@@ -574,6 +595,90 @@ def __init__(
)
self._feature_size += ch

### cross attention

self.input_blocks_ref = nn.ModuleList(
[EmbedSequential(nn.Conv2d(in_channel, ch, 3, padding=1))]
)

ds = 1
for level, mult in enumerate(channel_mults):
for _ in range(res_blocks[level]):
layers = [
ResBlock(
ch,
self.cond_embed_dim,
dropout,
out_channel=int(mult * inner_channel),
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
norm=norm,
efficient=efficient,
)
]
ch = int(mult * inner_channel)
if ds in attn_res:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
)
)
self.input_blocks_ref.append(EmbedSequential(*layers))

if level != len(channel_mults) - 1:
out_ch = ch
self.input_blocks_ref.append(
EmbedSequential(
ResBlock(
ch,
self.cond_embed_dim,
dropout,
out_channel=out_ch,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
norm=norm,
efficient=efficient,
)
if resblock_updown
else Downsample(ch, conv_resample, out_channel=out_ch)
)
)

self.middle_block_ref = EmbedSequential(
ResBlock(
ch,
self.cond_embed_dim,
dropout,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
norm=norm,
efficient=efficient,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
),
ResBlock(
ch,
self.cond_embed_dim,
dropout,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
norm=norm,
efficient=efficient,
),
)

###
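The block above rebuilds the encoder a second time (`input_blocks_ref` and `middle_block_ref`) with the same hyperparameters as the main path, so every attention resolution in the denoising encoder has a reference counterpart producing qkv. As a note on the design choice, a structurally identical branch with its own weights could also be obtained by deep-copying the already-built encoder; the sketch below illustrates that alternative and is not the repository's code:

```python
import copy
import torch.nn as nn


def build_ref_branch(input_blocks: nn.ModuleList, middle_block: nn.Module):
    """Return an independent, structurally identical copy of the encoder.

    Deep-copying gives the reference branch its own weights while guaranteeing
    the same block layout, which is what the explicit reconstruction above also
    achieves.
    """
    return copy.deepcopy(input_blocks), copy.deepcopy(middle_block)
```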

self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mults))[::-1]:
for i in range(res_blocks[level] + 1):
@@ -657,14 +762,27 @@ def __init__(
},
}

def compute_feats(self, input, embed_gammas):
def compute_feats(self, input, embed_gammas, ref=None):
if embed_gammas is None:
# Only for GAN
b = (input.shape[0], self.cond_embed_dim)
embed_gammas = torch.ones(b).to(input.device)

emb = embed_gammas

if ref is not None:
ref = torch.cat([ref, ref], dim=1)

qkv_list = []
h = ref.type(torch.float32)
for module in self.input_blocks_ref:
h, qkv_ref = module(h, emb, qkv_ref=None)

qkv_list.append(qkv_ref)

h, qkv_ref = self.middle_block_ref(h, emb, qkv_ref=None)
qkv_list.append(qkv_ref)

hs = []

h = input.type(torch.float32)
@@ -673,20 +791,21 @@
h = self.dwt(h)

for module in self.input_blocks:
h = module(h, emb)
h, _ = module(h, emb, qkv_ref=qkv_list.pop(0))

hs.append(h)

h = self.middle_block(h, emb)
h, _ = self.middle_block(h, emb, qkv_ref=qkv_list.pop(0))

outs, feats = h, hs
return outs, feats, emb
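`compute_feats` now makes two encoder passes: the reference image (concatenated with itself along channels to match `in_channel`) runs through `input_blocks_ref` and `middle_block_ref` with `qkv_ref=None`, and each block's qkv output is appended to `qkv_list`; the noisy input then runs through the original `input_blocks` and `middle_block`, consuming `qkv_list` in the same (FIFO) order. A schematic sketch of that two-pass flow with toy blocks; the function and block names here are illustrative only:

```python
def two_pass(ref_blocks, main_blocks, ref, x, emb):
    """Pass 1 encodes the reference and records qkv; pass 2 consumes it in order."""
    qkv_list = []
    h = ref
    for block in ref_blocks:                              # pass 1: reference encoding
        h, qkv = block(h, emb, qkv_ref=None)
        qkv_list.append(qkv)

    feats = []
    h = x
    for block in main_blocks:                             # pass 2: denoising encoding
        h, _ = block(h, emb, qkv_ref=qkv_list.pop(0))
        feats.append(h)
    return h, feats


def toy_block(x, emb, qkv_ref=None):
    """Toy block with the expected (x, emb, qkv_ref) -> (x, qkv) signature."""
    return x + emb, x


h, feats = two_pass([toy_block, toy_block], [toy_block, toy_block],
                    ref=1.0, x=0.0, emb=0.5)
```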

def forward(self, input, embed_gammas=None):
h, hs, emb = self.compute_feats(input, embed_gammas=embed_gammas)
def forward(self, input, embed_gammas=None, ref=None):
h, hs, emb = self.compute_feats(input, embed_gammas=embed_gammas, ref=ref)

for i, module in enumerate(self.output_blocks):
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)
h, _ = module(h, emb, ref)
h = h.type(input.dtype)
outh = self.out(h)

@@ -989,6 +1108,7 @@ def forward(self, input, embed_gammas=None):
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)
h = h.type(input.dtype)

outh = self.out(h)

if self.freq_space:
@@ -1010,19 +1130,3 @@ def extract(self, a, t, x_shape=(1, 1, 1, 1)):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))


if __name__ == "__main__":
b, c, h, w = 3, 6, 64, 64
timsteps = 100
model = UNet(
image_size=h,
in_channel=c,
inner_channel=64,
out_channel=3,
res_blocks=[2, 2, 2, 2],
attn_res=[8],
)
x = torch.randn((b, c, h, w))
emb = torch.ones((b,))
out = model(x, emb)
