Skip to content

Commit

Permalink
[Chore] remove class assignments for linear and conv. (huggingface#7553)
Browse files Browse the repository at this point in the history
* remove class assignments for linear and conv.

* fix: self.nn
  • Loading branch information
sayakpaul authored Apr 2, 2024
1 parent 5d83f50 commit 000fa82
Show file tree
Hide file tree
Showing 10 changed files with 38 additions and 61 deletions.
3 changes: 1 addition & 2 deletions src/diffusers/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,6 @@ def __init__(
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
linear_cls = nn.Linear

if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
Expand All @@ -651,7 +650,7 @@ def __init__(
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
Expand Down
19 changes: 8 additions & 11 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,25 +181,22 @@ def __init__(
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)

linear_cls = nn.Linear

self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)

if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
else:
self.to_k = None
self.to_v = None

if self.added_kv_proj_dim is not None:
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)

self.to_out = nn.ModuleList([])
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))

# set attention processor
Expand Down Expand Up @@ -706,7 +703,7 @@ def fuse_projections(self, fuse=True):
out_features = concatenated_weights.shape[0]

# create a new single projection layer and copy over the weights.
self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
Expand All @@ -717,7 +714,7 @@ def fuse_projections(self, fuse=True):
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]

self.to_kv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
Expand Down
3 changes: 1 addition & 2 deletions src/diffusers/models/downsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def __init__(
self.padding = padding
stride = 2
self.name = name
conv_cls = nn.Conv2d

if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
Expand All @@ -114,7 +113,7 @@ def __init__(
raise ValueError(f"unknown norm_type: {norm_type}")

if use_conv:
conv = conv_cls(
conv = nn.Conv2d(
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
)
else:
Expand Down
5 changes: 2 additions & 3 deletions src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,8 @@ def __init__(
sample_proj_bias=True,
):
super().__init__()
linear_cls = nn.Linear

self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
Expand All @@ -214,7 +213,7 @@ def __init__(
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

if post_act_fn is None:
self.post_act = None
Expand Down
21 changes: 8 additions & 13 deletions src/diffusers/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ def __init__(
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm

conv_cls = nn.Conv2d

if groups_out is None:
groups_out = groups

Expand All @@ -113,7 +111,7 @@ def __init__(
else:
raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")

self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

if self.time_embedding_norm == "ada_group": # ada_group
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
Expand All @@ -125,7 +123,7 @@ def __init__(
self.dropout = torch.nn.Dropout(dropout)

conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

self.nonlinearity = get_activation(non_linearity)

Expand All @@ -139,7 +137,7 @@ def __init__(

self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
self.conv_shortcut = nn.Conv2d(
in_channels,
conv_2d_out_channels,
kernel_size=1,
Expand Down Expand Up @@ -263,21 +261,18 @@ def __init__(
self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act

linear_cls = nn.Linear
conv_cls = nn.Conv2d

if groups_out is None:
groups_out = groups

self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = linear_cls(temb_channels, out_channels)
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
else:
Expand All @@ -287,7 +282,7 @@ def __init__(

self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

self.nonlinearity = get_activation(non_linearity)

Expand All @@ -313,7 +308,7 @@ def __init__(

self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
self.conv_shortcut = nn.Conv2d(
in_channels,
conv_2d_out_channels,
kernel_size=1,
Expand Down
11 changes: 4 additions & 7 deletions src/diffusers/models/transformers/transformer_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,6 @@ def __init__(
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim

conv_cls = nn.Conv2d
linear_cls = nn.Linear

# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
Expand Down Expand Up @@ -159,9 +156,9 @@ def __init__(

self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = linear_cls(in_channels, inner_dim)
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
Expand Down Expand Up @@ -222,9 +219,9 @@ def __init__(
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = linear_cls(inner_dim, in_channels)
self.proj_out = nn.Linear(inner_dim, in_channels)
else:
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
Expand Down
9 changes: 4 additions & 5 deletions src/diffusers/models/unets/unet_stable_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def forward(self, x):
class SDCascadeTimestepBlock(nn.Module):
def __init__(self, c, c_timestep, conds=[]):
super().__init__()
linear_cls = nn.Linear
self.mapper = linear_cls(c_timestep, c * 2)

self.mapper = nn.Linear(c_timestep, c * 2)
self.conds = conds
for cname in conds:
setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
setattr(self, f"mapper_{cname}", nn.Linear(c_timestep, c * 2))

def forward(self, x, t):
t = t.chunk(len(self.conds) + 1, dim=1)
Expand Down Expand Up @@ -94,12 +94,11 @@ def forward(self, x):
class SDCascadeAttnBlock(nn.Module):
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
super().__init__()
linear_cls = nn.Linear

self.self_attn = self_attn
self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))

def forward(self, x, kv):
kv = self.kv_mapper(kv)
Expand Down
3 changes: 1 addition & 2 deletions src/diffusers/models/upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def __init__(
self.use_conv_transpose = use_conv_transpose
self.name = name
self.interpolate = interpolate
conv_cls = nn.Conv2d

if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
Expand All @@ -131,7 +130,7 @@ def __init__(
elif use_conv:
if kernel_size is None:
kernel_size = 3
conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)

# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
Expand Down
15 changes: 5 additions & 10 deletions src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def forward(self, x):
class TimestepBlock(nn.Module):
def __init__(self, c, c_timestep):
super().__init__()
linear_cls = nn.Linear
self.mapper = linear_cls(c_timestep, c * 2)

self.mapper = nn.Linear(c_timestep, c * 2)

def forward(self, x, t):
a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
Expand All @@ -29,13 +29,10 @@ class ResBlock(nn.Module):
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
super().__init__()

conv_cls = nn.Conv2d
linear_cls = nn.Linear

self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
)

def forward(self, x, x_skip=None):
Expand Down Expand Up @@ -64,12 +61,10 @@ class AttnBlock(nn.Module):
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
super().__init__()

linear_cls = nn.Linear

self.self_attn = self_attn
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))

def forward(self, x, kv):
kv = self.kv_mapper(kv)
Expand Down
10 changes: 4 additions & 6 deletions src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
@register_to_config
def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
super().__init__()
conv_cls = nn.Conv2d
linear_cls = nn.Linear

self.c_r = c_r
self.projection = conv_cls(c_in, c, kernel_size=1)
self.projection = nn.Conv2d(c_in, c, kernel_size=1)
self.cond_mapper = nn.Sequential(
linear_cls(c_cond, c),
nn.Linear(c_cond, c),
nn.LeakyReLU(0.2),
linear_cls(c, c),
nn.Linear(c, c),
)

self.blocks = nn.ModuleList()
Expand All @@ -58,7 +56,7 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro
self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
self.out = nn.Sequential(
WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
conv_cls(c, c_in * 2, kernel_size=1),
nn.Conv2d(c, c_in * 2, kernel_size=1),
)

self.gradient_checkpointing = False
Expand Down

0 comments on commit 000fa82

Please sign in to comment.