Fix up loop-binding issues in ImageTransformerV2
akx committed Feb 13, 2024
1 parent c3f6608 commit 36dc02d
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions k_diffusion/models/image_transformer_v2.py
@@ -643,6 +643,18 @@ class MappingSpec:
     dropout: float
 
 
+def make_layer_factory(spec, mapping):
+    if isinstance(spec.self_attn, GlobalAttentionSpec):
+        return lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, NeighborhoodAttentionSpec):
+        return lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec):
+        return lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout)
+    elif isinstance(spec.self_attn, NoAttentionSpec):
+        return lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout)
+    raise ValueError(f"unsupported self attention spec {spec.self_attn}")
+
+
 # Model class
 
 class ImageTransformerDenoiserModelV2(nn.Module):
@@ -662,16 +674,7 @@ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_c

         self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList()
         for i, spec in enumerate(levels):
-            if isinstance(spec.self_attn, GlobalAttentionSpec):
-                layer_factory = lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, NeighborhoodAttentionSpec):
-                layer_factory = lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec):
-                layer_factory = lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout)
-            elif isinstance(spec.self_attn, NoAttentionSpec):
-                layer_factory = lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout)
-            else:
-                raise ValueError(f"unsupported self attention spec {spec.self_attn}")
+            layer_factory = make_layer_factory(spec, mapping)

             if i < len(levels) - 1:
                 self.down_levels.append(Level([layer_factory(i) for i in range(spec.depth)]))
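
For context on the commit title: a lambda created inside a loop closes over the loop variable itself, not the value it had on that iteration, so every closure sees whatever the variable holds when it is eventually called (the "loop-binding" issue linters flag for patterns like the removed if/elif chain above). Hoisting the lambda construction into a helper such as make_layer_factory makes each closure capture the helper's own arguments instead. The sketch below is an illustration only, with made-up names, not code from this repository.

# Minimal illustration of the late-binding pitfall that a factory helper avoids.
# All names here (broken, fixed, make_factory, width) are invented for the example.

def broken():
    factories = []
    for width in [64, 128, 256]:
        # Each lambda closes over the variable `width`, not the value it had
        # on this iteration of the loop.
        factories.append(lambda: width)
    # Every factory now sees the final value of `width`.
    return [f() for f in factories]  # [256, 256, 256]


def make_factory(width):
    # `width` is a parameter of this helper, so each returned closure gets
    # its own binding, fixed at the time make_factory is called.
    return lambda: width


def fixed():
    factories = [make_factory(w) for w in [64, 128, 256]]
    return [f() for f in factories]  # [64, 128, 256]


if __name__ == "__main__":
    print(broken())
    print(fixed())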
