From 000fa82a1e6a7730dddb666ff6b2f681de3c2746 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Tue, 2 Apr 2024 13:01:04 +0530
Subject: [PATCH] [Chore] remove class assignments for linear and conv. (#7553)

* remove class assignments for linear and conv.

* fix: self.nn
---
 src/diffusers/models/attention.py             |  3 +--
 src/diffusers/models/attention_processor.py   | 19 +++++++----------
 src/diffusers/models/downsampling.py          |  3 +--
 src/diffusers/models/embeddings.py            |  5 ++---
 src/diffusers/models/resnet.py                | 21 +++++++------------
 .../models/transformers/transformer_2d.py     | 11 ++++------
 .../models/unets/unet_stable_cascade.py       |  9 ++++----
 src/diffusers/models/upsampling.py            |  3 +--
 .../wuerstchen/modeling_wuerstchen_common.py  | 15 +++++--------
 .../wuerstchen/modeling_wuerstchen_prior.py   | 10 ++++-----
 10 files changed, 38 insertions(+), 61 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 651c928adc39..50866e3a7a8c 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -634,7 +634,6 @@ def __init__(
         if inner_dim is None:
             inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        linear_cls = nn.Linear
 
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim, bias=bias)
@@ -651,7 +650,7 @@ def __init__(
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
+        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 0c6dfe068d5c..1fd29ce708c8 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -181,25 +181,22 @@ def __init__(
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )
 
-        linear_cls = nn.Linear
-
-        self.linear_cls = linear_cls
-        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
 
         if not self.only_cross_attention:
             # only relevant for the `AddedKVProcessor` classes
-            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
-            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
         else:
             self.to_k = None
             self.to_v = None
 
         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
-            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
 
         self.to_out = nn.ModuleList([])
-        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
@@ -706,7 +703,7 @@ def fuse_projections(self, fuse=True):
             out_features = concatenated_weights.shape[0]
 
             # create a new single projection layer and copy over the weights.
-            self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_qkv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
@@ -717,7 +714,7 @@ def fuse_projections(self, fuse=True):
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_kv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_kv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py
index 9ae28e950e83..6d556e1e67ac 100644
--- a/src/diffusers/models/downsampling.py
+++ b/src/diffusers/models/downsampling.py
@@ -102,7 +102,6 @@ def __init__(
         self.padding = padding
         stride = 2
         self.name = name
-        conv_cls = nn.Conv2d
 
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -114,7 +113,7 @@ def __init__(
             raise ValueError(f"unknown norm_type: {norm_type}")
 
         if use_conv:
-            conv = conv_cls(
+            conv = nn.Conv2d(
                 self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
             )
         else:
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index c15ff24cbcda..85b1e4944ed2 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -199,9 +199,8 @@ def __init__(
         sample_proj_bias=True,
     ):
         super().__init__()
-        linear_cls = nn.Linear
 
-        self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
 
         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -214,7 +213,7 @@ def __init__(
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
 
         if post_act_fn is None:
             self.post_act = None
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index ec75861e2da0..88c7a01be6bf 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -101,8 +101,6 @@ def __init__(
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm
 
-        conv_cls = nn.Conv2d
-
         if groups_out is None:
             groups_out = groups
 
@@ -113,7 +111,7 @@ def __init__(
         else:
             raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")
 
-        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
         if self.time_embedding_norm == "ada_group":  # ada_group
             self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
@@ -125,7 +123,7 @@ def __init__(
         self.dropout = torch.nn.Dropout(dropout)
 
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
         self.nonlinearity = get_activation(non_linearity)
 
@@ -139,7 +137,7 @@ def __init__(
 
         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = conv_cls(
+            self.conv_shortcut = nn.Conv2d(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,
@@ -263,21 +261,18 @@ def __init__(
         self.time_embedding_norm = time_embedding_norm
         self.skip_time_act = skip_time_act
 
-        linear_cls = nn.Linear
-        conv_cls = nn.Conv2d
-
         if groups_out is None:
             groups_out = groups
 
         self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
 
-        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
         if temb_channels is not None:
             if self.time_embedding_norm == "default":
-                self.time_emb_proj = linear_cls(temb_channels, out_channels)
+                self.time_emb_proj = nn.Linear(temb_channels, out_channels)
             elif self.time_embedding_norm == "scale_shift":
-                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
+                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
             else:
                 raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
         else:
@@ -287,7 +282,7 @@ def __init__(
         self.dropout = torch.nn.Dropout(dropout)
 
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
         self.nonlinearity = get_activation(non_linearity)
 
@@ -313,7 +308,7 @@ def __init__(
 
         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = conv_cls(
+            self.conv_shortcut = nn.Conv2d(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index 6b2cd0431231..0658a7daa241 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -117,9 +117,6 @@ def __init__(
         self.attention_head_dim = attention_head_dim
         inner_dim = num_attention_heads * attention_head_dim
 
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
-
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
         self.is_input_continuous = (in_channels is not None) and (patch_size is None)
@@ -159,9 +156,9 @@ def __init__(
 
             self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
             if use_linear_projection:
-                self.proj_in = linear_cls(in_channels, inner_dim)
+                self.proj_in = nn.Linear(in_channels, inner_dim)
             else:
-                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
             assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -222,9 +219,9 @@ def __init__(
         if self.is_input_continuous:
             # TODO: should use out_channels for continuous projections
            if use_linear_projection:
-                self.proj_out = linear_cls(inner_dim, in_channels)
+                self.proj_out = nn.Linear(inner_dim, in_channels)
             else:
-                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             self.norm_out = nn.LayerNorm(inner_dim)
             self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py
index 197ddeec757d..6227f7413a3c 100644
--- a/src/diffusers/models/unets/unet_stable_cascade.py
+++ b/src/diffusers/models/unets/unet_stable_cascade.py
@@ -41,11 +41,11 @@ def forward(self, x):
 class SDCascadeTimestepBlock(nn.Module):
     def __init__(self, c, c_timestep, conds=[]):
         super().__init__()
-        linear_cls = nn.Linear
-        self.mapper = linear_cls(c_timestep, c * 2)
+
+        self.mapper = nn.Linear(c_timestep, c * 2)
         self.conds = conds
         for cname in conds:
-            setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
+            setattr(self, f"mapper_{cname}", nn.Linear(c_timestep, c * 2))
 
     def forward(self, x, t):
         t = t.chunk(len(self.conds) + 1, dim=1)
@@ -94,12 +94,11 @@ def forward(self, x):
 class SDCascadeAttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
-        linear_cls = nn.Linear
 
         self.self_attn = self_attn
         self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
 
     def forward(self, x, kv):
         kv = self.kv_mapper(kv)
diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py
index 4ecf6ebc26d2..af6e15db308b 100644
--- a/src/diffusers/models/upsampling.py
+++ b/src/diffusers/models/upsampling.py
@@ -110,7 +110,6 @@ def __init__(
         self.use_conv_transpose = use_conv_transpose
         self.name = name
         self.interpolate = interpolate
-        conv_cls = nn.Conv2d
 
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -131,7 +130,7 @@ def __init__(
         elif use_conv:
             if kernel_size is None:
                 kernel_size = 3
-            conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+            conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
index 101acafcff1f..73e71b3076fb 100644
--- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
@@ -17,8 +17,8 @@ def forward(self, x):
 class TimestepBlock(nn.Module):
     def __init__(self, c, c_timestep):
         super().__init__()
-        linear_cls = nn.Linear
-        self.mapper = linear_cls(c_timestep, c * 2)
+
+        self.mapper = nn.Linear(c_timestep, c * 2)
 
     def forward(self, x, t):
         a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
@@ -29,13 +29,10 @@ class ResBlock(nn.Module):
     def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
         super().__init__()
 
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
-
-        self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.channelwise = nn.Sequential(
-            linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
+            nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
         )
 
     def forward(self, x, x_skip=None):
@@ -64,12 +61,10 @@ class AttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
 
-        linear_cls = nn.Linear
-
         self.self_attn = self_attn
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
-        self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
+        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
 
     def forward(self, x, kv):
         kv = self.kv_mapper(kv)
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
index 8cc294eaf79a..a59661c3c3f5 100644
--- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
@@ -40,15 +40,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
     @register_to_config
     def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
         super().__init__()
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
 
         self.c_r = c_r
-        self.projection = conv_cls(c_in, c, kernel_size=1)
+        self.projection = nn.Conv2d(c_in, c, kernel_size=1)
         self.cond_mapper = nn.Sequential(
-            linear_cls(c_cond, c),
+            nn.Linear(c_cond, c),
             nn.LeakyReLU(0.2),
-            linear_cls(c, c),
+            nn.Linear(c, c),
         )
 
         self.blocks = nn.ModuleList()
@@ -58,7 +56,7 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro
             self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
         self.out = nn.Sequential(
             WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
-            conv_cls(c, c_in * 2, kernel_size=1),
+            nn.Conv2d(c, c_in * 2, kernel_size=1),
         )
 
         self.gradient_checkpointing = False
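
Note: every hunk above applies the same mechanical change: local aliases such as linear_cls = nn.Linear and conv_cls = nn.Conv2d (and the stored self.linear_cls) are dropped in favor of calling nn.Linear / nn.Conv2d directly. A minimal before/after sketch of that pattern, using a hypothetical ToyBlock module rather than any class touched by this patch:

    import torch.nn as nn

    # Before: the alias only adds indirection; behavior is identical.
    class ToyBlockBefore(nn.Module):
        def __init__(self, dim):
            super().__init__()
            linear_cls = nn.Linear
            self.proj = linear_cls(dim, dim)

    # After: call nn.Linear directly, mirroring the change made across the files above.
    class ToyBlockAfter(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim, dim)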