feat(ml): option for max_sequence_length of video generation
wr0124 committed Sep 25, 2024
1 parent 2c53948 commit efa0e8c
Showing 3 changed files with 9 additions and 14 deletions.
models/diffusion_networks.py (4 changes: 2 additions & 2 deletions)

@@ -39,7 +39,7 @@ def define_G(
     G_attn_nb_mask_attn,
     G_attn_nb_mask_input,
     G_spectral,
-    G_unet_vid_max_frame,
+    G_unet_vid_max_sequence_length,
     jg_dir,
     G_padding_type,
     G_config_segformer,
@@ -150,7 +150,7 @@ def define_G(
             efficient=G_unet_mha_vit_efficient,
             cond_embed_dim=cond_embed_dim,
             freq_space=train_feat_wavelet,
-            max_frame=G_unet_vid_max_frame,
+            max_sequence_length=G_unet_vid_max_sequence_length,
         )

     elif G_netG == "unet_mha_ref_attn":
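The rename here is a pure pass-through: define_G receives the option under its CLI-derived name and hands it to the video U-Net under the constructor's parameter name. A minimal, self-contained sketch of that forwarding pattern (class and function bodies are illustrative stand-ins, not the repository's code):

class UNetVid:
    """Stand-in for the video U-Net; only the renamed parameter is modeled."""

    def __init__(self, max_sequence_length=25):
        self.max_sequence_length = max_sequence_length


def define_G(G_netG, G_unet_vid_max_sequence_length=25, **kwargs):
    # The CLI option name (left) maps onto the constructor keyword (right),
    # mirroring the two changed lines above.
    if G_netG == "unet_vid":
        return UNetVid(max_sequence_length=G_unet_vid_max_sequence_length)
    raise ValueError(f"unknown generator: {G_netG}")


net = define_G("unet_vid", G_unet_vid_max_sequence_length=48)
assert net.max_sequence_length == 48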
models/modules/unet_generator_attn/unet_generator_attn_vid.py (13 changes: 4 additions & 9 deletions)

@@ -397,7 +397,6 @@ def __init__(
             temporal_position_encoding=temporal_position_encoding,
             temporal_position_encoding_max_len=temporal_position_encoding_max_len,
         )
-
         if zero_initialize:
             self.temporal_transformer.proj_out = zero_module(
                 self.temporal_transformer.proj_out
@@ -532,7 +531,6 @@ def __init__(
         temporal_position_encoding_max_len=25,
     ):
         super().__init__()
-
         attention_blocks = []
         norms = []

@@ -555,7 +553,6 @@ def __init__(
                 )
             )
             norms.append(nn.LayerNorm(dim))
-
         self.attention_blocks = nn.ModuleList(attention_blocks)
         self.norms = nn.ModuleList(norms)

@@ -962,7 +959,6 @@ def __init__(

         self.attention_mode = attention_mode
         self.is_cross_attention = kwargs["cross_attention_dim"] is not None
-
         self.pos_encoder = (
             PositionalEncoding(
                 kwargs["query_dim"],
@@ -1109,7 +1105,7 @@ def __init__(
         use_new_attention_order=True,  # False,
         efficient=False,
         freq_space=False,
-        max_frame=25,
+        max_sequence_length=25,
     ):
         super().__init__()

@@ -1132,8 +1128,7 @@ def __init__(
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
         self.freq_space = freq_space
-        self.max_frame = max_frame
-
+        self.max_sequence_length = max_sequence_length
         if self.freq_space:
             from ..freq_utils import InverseHaarTransform, HaarTransform

@@ -1188,7 +1183,7 @@ def __init__(
                 attention_block_types=("Temporal_self", "Temporal_Self"),
                 cross_frame_attention_mode=None,
                 temporal_position_encoding=True,
-                temporal_position_encoding_max_len=25,  # self.max_frame,
+                temporal_position_encoding_max_len=self.max_sequence_length,
                 temporal_attention_dim_div=1,
                 zero_initialize=True,
             )
@@ -1294,7 +1289,7 @@ def __init__(
                 attention_block_types=("Temporal_self", "Temporal_Self"),
                 cross_frame_attention_mode=None,
                 temporal_position_encoding=True,
-                temporal_position_encoding_max_len=25,  # self.max_frame,
+                temporal_position_encoding_max_len=self.max_sequence_length,
                 temporal_attention_dim_div=1,
                 zero_initialize=True,
             )
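The two constructor calls above now point temporal_position_encoding_max_len at self.max_sequence_length instead of a hard-coded 25. For context, a minimal sketch of a standard sinusoidal PositionalEncoding of the kind such a cap controls; this follows the usual Transformer recipe and is an assumption, not the repository's exact module:

import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional table; assumes an even d_model."""

    def __init__(self, d_model, max_len=25):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)  # precomputed once, sized by max_len

    def forward(self, x):
        # x: (batch, frames, d_model). If frames exceeds max_len, the addition
        # fails with a shape mismatch, which is why making the cap
        # configurable matters for longer clips.
        return x + self.pe[:, : x.size(1)]

With max_len frozen at 25, any clip longer than 25 frames would outgrow the table; wiring the cap to max_sequence_length lets users size it to their videos.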
options/common_options.py (6 changes: 3 additions & 3 deletions)

@@ -387,10 +387,10 @@ def initialize(self, parser):
             help="Patch size for HDIT, e.g. 4 for 4x4 patches",
         )
         parser.add_argument(
-            "--G_unet_vid_max_frame",
-            default=24,
+            "--G_unet_vid_max_sequence_length",
+            default=25,
             type=int,
-            help="max frame number for unet_vid in the PositionalEncoding",
+            help="max frame number(sequence length) for unet_vid in the PositionalEncoding",
         )

         # parser.add_argument(
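A short sketch of how the renamed flag behaves once registered; the parser here is a bare argparse stand-in rather than the project's full option set:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--G_unet_vid_max_sequence_length",
    default=25,
    type=int,
    help="max frame number (sequence length) for unet_vid in the PositionalEncoding",
)

# E.g. training on 48-frame clips: raise the positional-encoding cap to match.
opt = parser.parse_args(["--G_unet_vid_max_sequence_length", "48"])
assert opt.G_unet_vid_max_sequence_length == 48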
