From efa0e8c866f2d030158eaa423ec9e07dd816bf41 Mon Sep 17 00:00:00 2001
From: julie wang
Date: Mon, 23 Sep 2024 16:15:05 +0200
Subject: [PATCH] feat(ml): option for max_sequence_length of video generation

---
 models/diffusion_networks.py                       |  4 ++--
 .../unet_generator_attn/unet_generator_attn_vid.py | 13 ++++---------
 options/common_options.py                          |  6 +++---
 3 files changed, 9 insertions(+), 14 deletions(-)

diff --git a/models/diffusion_networks.py b/models/diffusion_networks.py
index 521fb4a68..e9848ec81 100644
--- a/models/diffusion_networks.py
+++ b/models/diffusion_networks.py
@@ -39,7 +39,7 @@ def define_G(
     G_attn_nb_mask_attn,
     G_attn_nb_mask_input,
     G_spectral,
-    G_unet_vid_max_frame,
+    G_unet_vid_max_sequence_length,
     jg_dir,
     G_padding_type,
     G_config_segformer,
@@ -150,7 +150,7 @@ def define_G(
             efficient=G_unet_mha_vit_efficient,
             cond_embed_dim=cond_embed_dim,
             freq_space=train_feat_wavelet,
-            max_frame=G_unet_vid_max_frame,
+            max_sequence_length=G_unet_vid_max_sequence_length,
         )
 
     elif G_netG == "unet_mha_ref_attn":
diff --git a/models/modules/unet_generator_attn/unet_generator_attn_vid.py b/models/modules/unet_generator_attn/unet_generator_attn_vid.py
index 9fd2b0ab2..9fa818233 100644
--- a/models/modules/unet_generator_attn/unet_generator_attn_vid.py
+++ b/models/modules/unet_generator_attn/unet_generator_attn_vid.py
@@ -397,7 +397,6 @@ def __init__(
             temporal_position_encoding=temporal_position_encoding,
             temporal_position_encoding_max_len=temporal_position_encoding_max_len,
         )
-
         if zero_initialize:
             self.temporal_transformer.proj_out = zero_module(
                 self.temporal_transformer.proj_out
@@ -532,7 +531,6 @@ def __init__(
         temporal_position_encoding_max_len=25,
     ):
         super().__init__()
-
         attention_blocks = []
         norms = []
 
@@ -555,7 +553,6 @@ def __init__(
             )
         )
         norms.append(nn.LayerNorm(dim))
-
        self.attention_blocks = nn.ModuleList(attention_blocks)
        self.norms = nn.ModuleList(norms)
 
@@ -962,7 +959,6 @@ def __init__(
 
         self.attention_mode = attention_mode
         self.is_cross_attention = kwargs["cross_attention_dim"] is not None
-
        self.pos_encoder = (
            PositionalEncoding(
                kwargs["query_dim"],
@@ -1109,7 +1105,7 @@ def __init__(
         use_new_attention_order=True,  # False,
         efficient=False,
         freq_space=False,
-        max_frame=25,
+        max_sequence_length=25,
     ):
         super().__init__()
 
@@ -1132,8 +1128,7 @@ def __init__(
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
         self.freq_space = freq_space
-        self.max_frame = max_frame
-
+        self.max_sequence_length = max_sequence_length
         if self.freq_space:
             from ..freq_utils import InverseHaarTransform, HaarTransform
 
@@ -1188,7 +1183,7 @@ def __init__(
                     attention_block_types=("Temporal_self", "Temporal_Self"),
                     cross_frame_attention_mode=None,
                     temporal_position_encoding=True,
-                    temporal_position_encoding_max_len=25,  # self.max_frame,
+                    temporal_position_encoding_max_len=self.max_sequence_length,
                     temporal_attention_dim_div=1,
                     zero_initialize=True,
                 )
@@ -1294,7 +1289,7 @@ def __init__(
                     attention_block_types=("Temporal_self", "Temporal_Self"),
                     cross_frame_attention_mode=None,
                     temporal_position_encoding=True,
-                    temporal_position_encoding_max_len=25,  # self.max_frame,
+                    temporal_position_encoding_max_len=self.max_sequence_length,
                     temporal_attention_dim_div=1,
                     zero_initialize=True,
                 )
diff --git a/options/common_options.py b/options/common_options.py
index c4ae006b5..1de293fcd 100644
--- a/options/common_options.py
+++ b/options/common_options.py
@@ -387,10 +387,10 @@ def initialize(self, parser):
             help="Patch size for HDIT, e.g. 4 for 4x4 patches",
         )
         parser.add_argument(
-            "--G_unet_vid_max_frame",
-            default=24,
+            "--G_unet_vid_max_sequence_length",
+            default=25,
             type=int,
-            help="max frame number for unet_vid in the PositionalEncoding",
+            help="max frame number (sequence length) for unet_vid in the PositionalEncoding",
         )
 
         # parser.add_argument(
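
Note on the new option: --G_unet_vid_max_sequence_length sizes the precomputed
temporal positional-encoding table, so it must be at least the number of frames
per generated sequence; the patch also aligns the option default (previously
24) with the previously hard-coded table size of 25. For illustration, a
minimal sketch of such a max_len-capped module, assuming a standard
Transformer-style sinusoidal encoding with an even embedding width (the
repository's actual PositionalEncoding may differ in detail):

    import math

    import torch
    import torch.nn as nn


    class PositionalEncoding(nn.Module):
        """Sinusoidal positional encoding with a fixed-size table (sketch)."""

        def __init__(self, d_model, dropout=0.0, max_len=25):
            super().__init__()
            self.dropout = nn.Dropout(p=dropout)
            position = torch.arange(max_len).unsqueeze(1)
            div_term = torch.exp(
                torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
            )
            # The table is built once with max_len rows; frames beyond
            # max_len have no positional entry.
            pe = torch.zeros(1, max_len, d_model)
            pe[0, :, 0::2] = torch.sin(position * div_term)
            pe[0, :, 1::2] = torch.cos(position * div_term)
            self.register_buffer("pe", pe)

        def forward(self, x):
            # x: (batch, sequence_length, d_model); sequence_length must not
            # exceed max_len for the slice below to cover every position.
            x = x + self.pe[:, : x.size(1)]
            return self.dropout(x)

Under that assumption, a sequence longer than max_len fails with a broadcasting
error in the addition above, which is why temporal_position_encoding_max_len is
now wired to self.max_sequence_length instead of the hard-coded 25.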