Add modality info and transforms for raw RGB, tokens, video descriptions, video transcripts, and video bounding boxes. #1

Open · wants to merge 9 commits into main
@@ -1,7 +1,7 @@
# Half of all samples from this mixture are rgb2dense, half are all2all
rgb@224:
input_alphas: [1000.0, 0.5]
target_alphas: [0., 0.] # RGB is not a target
rgb@224: # Each column is one sampling strategy to mix over (see the sketch below this config).
input_alphas: [1000.0, 0.5] # NOTE: input alphas control how much weight each modality gets when sampling input tokens for encoding.
target_alphas: [0., 0.] # RGB is not a target # NOTE: target alphas control how much weight each modality gets when being unmasked/sampled during decoding.
tok_rgb@224:
input_alphas: [0., 0.5]
target_alphas: [0.5, 0.5]
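As a rough sketch of how these alpha columns are meant to behave (hypothetical helper names, not the repository's actual sampler), each column defines one Dirichlet distribution over the modalities, and the concentration values decide how the token budget is split between them:

```python
import numpy as np

# Illustrative mixture-of-Dirichlets sampling. Each column across the per-modality
# alpha lists is one sampling strategy; a very large alpha (e.g. 1000.0) pins almost
# the entire budget on that modality, while equal small alphas give a noisy split.
input_alphas = {
    "rgb@224":     [1000.0, 0.5],
    "tok_rgb@224": [0.0,    0.5],
}

def sample_input_budget(alphas_by_mod, num_input_tokens=128, eps=1e-9):
    mods = list(alphas_by_mod)
    n_strategies = len(next(iter(alphas_by_mod.values())))
    col = np.random.randint(n_strategies)                      # pick one strategy (column)
    alphas = np.array([alphas_by_mod[m][col] for m in mods]) + eps
    props = np.random.dirichlet(alphas)                        # per-modality share of the budget
    return dict(zip(mods, (props * num_input_tokens).round().astype(int)))

print(sample_input_budget(input_alphas))
# Column 0: nearly all 128 input tokens go to rgb@224; column 1: a roughly even split.
```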
@@ -0,0 +1,7 @@
# TODO: what is the naming convention for these configs?
video_rgb@224:
input_alphas: [1000.0, 0.5]
target_alphas: [0., 0.] # RGB is not a target
video_tok_rgb@224:
input_alphas: [0., 0.5]
target_alphas: [0.5, 0.5]
@@ -0,0 +1,11 @@
# TODO: what is the naming convention for these configs?
video_rgb@224:
input_alphas: [1000.0, 0.5]
target_alphas: [0., 0.] # RGB is not a target
video_tok_rgb@224:
input_alphas: [0., 0.5]
target_alphas: [0.5, 0.5]
det:
input_alphas: [0., 0.5]
target_alphas: [0.5, 0.5]
keep: ['random', 'random']
44 changes: 44 additions & 0 deletions cfgs/default/4m/data/video/mix_mod3_rgb_tok_det_to_all_a0.5.yaml
@@ -0,0 +1,44 @@
train:
datasets:
my_video_dataset:
type: multimodal

# Input and output domain names, separated by hyphen
in_domains: video_rgb@224-video_tok_rgb@224-video_det
out_domains: video_rgb@224-video_tok_rgb@224-video_det

# Dirichlet alphas concentration parameter for input and output.
# Can be either one value, or one value per input modality separated by hyphen.
input_alphas: null
target_alphas: null
# Path to specific alphas configuration to enable mixture of Dirichlets.
# If provided, overrides input_alphas and target_alphas
alphas_config: "cfgs/default/4m/alphas_mixture/video/mix_mod3_rgb_tok_det_a0.5.yaml"

# Optionally, min_input_tokens, min_target_tokens, num_input_tokens, num_target_tokens can be specified here
# If so, they will override the values provided in the main config

# Data can either be local or on cloud storage (e.g. S3), see data docs for more info
# Use braceexpand notation to indicate shard range (e.g. shard-{0000..9999}.tar)
# Use brackets to indicate multiple modalities (e.g. [modality1,modality2,modality3])
data_path: '/store/swissai/a08/data/4m/train/[video_rgb,video_rgb_tok,video_det]/shard-{00000..00100}.tar' # TODO: need to reformat the data correctly here.
use_wds: True # Use webdataset
wds_n_repeats: 4 # Number of repeats for webdataset loader to improve efficiency
wds_shuffle_buffer_tar: 1_000 # Webdatasets shuffle buffer after loading tar files
wds_shuffle_buffer_repeat: 1_000 # Webdatasets shuffle buffer after repeating samples

main_augment_domain: video_rgb@224 # Select from which modality to get the original full image size (mostly important for resizing bounding boxes)
aligned_captions: True # Align captions to crop_settings # TODO: tbd?
tok_train_aug: True # Apply data augmentation to tokens (if multiple crop settings are available) # TODO: tbd?

# modality_name_map: # Use modality_name_map to define a mapping from a folder name to a modality name
# tok_rgb_folder_name: tok_rgb@224
# tok_depth_folder_name: tok_depth@224
# ...

weights: [1.0] # Sampling weights for the training datasets

val:
datasets:
my_video_dataset:
data_path: '/store/swissai/a08/data/4m/val/[video_rgb,video_rgb_tok,video_det]/shard-{00000..00100}.tar'
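For readers unfamiliar with the shard notation used in data_path above, here is a small sketch (assuming the braceexpand package that webdataset-style loaders commonly rely on; the helper name is hypothetical) of how a bracketed, brace-expanded path could be enumerated into per-modality shard lists:

```python
from braceexpand import braceexpand

# Illustrative only: expand "[mod1,mod2,mod3]" into one path template per modality,
# then expand the "{00000..00100}" shard range in each template.
def expand_data_path(data_path):
    if "[" in data_path:
        prefix, rest = data_path.split("[", 1)
        mods, suffix = rest.split("]", 1)
        templates = [prefix + m + suffix for m in mods.split(",")]
    else:
        templates = [data_path]
    return {t: list(braceexpand(t)) for t in templates}

shards = expand_data_path(
    "/store/swissai/a08/data/4m/train/[video_rgb,video_rgb_tok,video_det]/shard-{00000..00100}.tar"
)
for template, paths in shards.items():
    print(template, "->", len(paths), "shards")  # 101 shards per modality
```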
44 changes: 44 additions & 0 deletions cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml
@@ -0,0 +1,44 @@
train:
datasets:
my_video_dataset:
type: multimodal

# Input and output domain names, separated by hyphen
in_domains: video_rgb@224-video_tok_rgb@224
out_domains: video_rgb@224-video_tok_rgb@224

# Dirichlet alphas concentration parameter for input and output.
# Can be either one value, or one value per input modality separated by hyphen.
input_alphas: null
target_alphas: null
# Path to specific alphas configuration to enable mixture of Dirichlets.
# If provided, overrides input_alphas and target_alphas
alphas_config: "cfgs/default/4m/alphas_mixture/video/mix_mod3_rgb_tok_a0.5.yaml"

# Optionally, min_input_tokens, min_target_tokens, num_input_tokens, num_target_tokens can be specified here
# If so, they will override the values provided in the main config

# Data can either be local or on cloud storage (e.g. S3), see data docs for more info
# Use braceexpand notation to indicate shard range (e.g. shard-{0000..9999}.tar)
# Use brackets to indicate multiple modalities (e.g. [modality1,modality2,modality3])
data_path: '/store/swissai/a08/data/4m/cleaned/train/[video_rgb,video_tok_rgb]/0000000000.tar' # TODO: need to reformat the data correctly here.
use_wds: True # Use webdataset
wds_n_repeats: 4 # Number of repeats for webdataset loader to improve efficiency
wds_shuffle_buffer_tar: 1_000 # Webdatasets shuffle buffer after loading tar files
wds_shuffle_buffer_repeat: 1_000 # Webdatasets shuffle buffer after repeating samples

main_augment_domain: video_rgb@224 # Select from which modality to get the original full image size (mostly important for resizing bounding boxes)
aligned_captions: True # Align captions to crop_settings # TODO: tbd?
tok_train_aug: True # Apply data augmentation to tokens (if multiple crop settings are available) # TODO: tbd?

# modality_name_map: # Use modality_name_map to define a mapping from a folder name to a modality name
# tok_rgb_folder_name: tok_rgb@224
# tok_depth_folder_name: tok_depth@224
# ...

weights: [1.0] # Sampling weights for the training datasets

# val:
# datasets:
# my_video_dataset:
# data_path: '/store/swissai/a08/data/4m/val/[video_rgb,video_tok_rgb]/00000{00175..00199}.tar'
46 changes: 46 additions & 0 deletions cfgs/default/4m/models/video/4m-b_mod3.yaml
@@ -0,0 +1,46 @@
# Config for DDP

# Arch: SwiGLU No Bias
# Modalities: Mix of rgb2all and all2all, with alphas=0.5
# To be run on 64 GPUs for batch size = 8192
run_name: auto

# Input & output
num_input_tokens: 128
num_target_tokens: 128
loss_type: mod

# Architecture
model: fm_base_12e_12d_swiglu_nobias
patch_size: 16
input_size: 224
dtype: bfloat16
tokenizer_path: "fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"

# Train
epochs: -1
total_tokens: 500 # in billions
opt: adamw
blr: 0.0001 # this is base_lr = 1e-4, lr = base_lr * batch_size / 256
min_blr: 0.
warmup_epochs: -1
warmup_tokens: 10 # in billions
batch_size: 128 # 128 x 64 = 8192

# Data

data_config: "cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml"
s3_data_endpoint: null # Change me
eval_freq: 1
fixed_eval: True
epoch_size: 10_000_000 # Number of samples per "epoch"

# Saving
save_ckpt_freq: 1
output_dir: 'output/auto'

# Wandb
log_wandb: False # Set to True to log to Weights & Biases
wandb_project: '4m-train'
wandb_entity: null # Change if needed
wandb_run_name: auto
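To make the blr comment above concrete, the effective learning rate under the stated linear scaling rule (lr = base_lr * batch_size / 256) works out as follows; this is just arithmetic on the config's own values:

```python
# Effective LR under linear scaling: lr = blr * total_batch_size / 256
blr = 1e-4                      # base learning rate from this config
total_batch = 128 * 64          # per-GPU batch of 128 on 64 GPUs = 8192
lr = blr * total_batch / 256
print(lr)                       # 0.0032
```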
2 changes: 1 addition & 1 deletion fourm/data/masking.py
@@ -128,7 +128,7 @@ def chunk_span_masking(sequence_chunks: List[List[int]], sentinel_to_id: Dict[in



class UnifiedMasking(object):
class UnifiedMasking(object): # this defines the masking logic
def __init__(self,
modality_info: Dict,
text_tokenizer: Optional[Tokenizer],
49 changes: 48 additions & 1 deletion fourm/data/modality_info.py
@@ -29,6 +29,11 @@
ColorPaletteTransform,
SAMInstanceTokTransform,
SAMInstanceTransform,
VideoDescriptionTransform,
VideoDetectionTransform,
VideoRGBTransform,
VideoTokTransform,
VideoTranscriptTransform,
)
from fourm.models.decoder_embeddings import ImageTokenDecoderEmbedding, SequenceDecoderEmbedding
from fourm.models.encoder_embeddings import (
@@ -39,6 +44,7 @@
)
from fourm.utils import generate_uint15_hash

# Specifications about different modalities
MODALITY_INFO = {
# 4M-7 modalities
"rgb@224": {
@@ -406,6 +412,38 @@
},
}

VIDEO_MODALITY_INFO = {
### Video modalities
# TODO: do we need to keep the image-only versions, or should these video modalities generalize over single images as well?
"video_rgb@224": {
**MODALITY_INFO["rgb@224"],
"id": generate_uint15_hash("video_rgb@224"),
"path": "video_rgb", # TODO: video_rgb or keep aas rgb? (probably only keep rgb if this generalizes over single images too)
},
"video_description": {
**MODALITY_INFO["caption"], # TODO: do we want to increase the default 'max_tokens/max_length' from 256?
"id": generate_uint15_hash("video_description"),
},
"video_transcript": {
**MODALITY_INFO["caption"], # TODO: do we want to increase the default 'max_tokens/max_length' from 256?
"id": generate_uint15_hash("video_transcript"),
},
"video_det": {
**MODALITY_INFO["det"],
"id": generate_uint15_hash("video_det"),
},
"video_tok_rgb@224": {
**MODALITY_INFO["tok_rgb@224"],
"id": generate_uint15_hash("video_tok_rgb@224"),
},
"video_tok_clip@224": {
**MODALITY_INFO["tok_clip@224"],
"id": generate_uint15_hash("video_tok_clip@224"),
},
}

MODALITY_INFO = {**MODALITY_INFO, **VIDEO_MODALITY_INFO}

# Note: @res suffix is ignored for modality transforms
MODALITY_TRANSFORMS = {
# 4M-7 modalities
@@ -414,7 +452,7 @@
"det": DetectionTransform(
det_threshold=0.6, det_max_instances=None, bbox_order="dist_to_orig", coord_bins=1000, min_visibility=0.0
),
"tok_rgb": TokTransform(),
"tok_rgb": TokTransform(), # tok_ indicates its a token representation
"tok_depth": TokTransform(),
"tok_normal": TokTransform(),
"tok_semseg": TokTransform(),
Expand All @@ -435,6 +473,15 @@
"tok_imagebind_global": TokTransform(),
# Other
"mask_valid": MaskTransform(mask_pool_size=1),
# Video
"video_rgb": VideoRGBTransform(imagenet_default_mean_and_std=True), # TODO: check parameters
"video_tok_rgb": VideoTokTransform(), # tok_ indicates its a token representation
"video_tok_clip": VideoTokTransform(), # TODO: check parameters
"video_description": VideoDescriptionTransform(aligned_captions=True), # TODO: check parameters
"video_transcript": VideoTranscriptTransform(aligned_captions=True), # TODO: check parameters
"video_det": VideoDetectionTransform(
det_threshold=0.6, det_max_instances=None, bbox_order="dist_to_orig", coord_bins=1000, min_visibility=0.0
), # TODO: check parameters
}

MODALITY_TRANSFORMS_DIVAE = {
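Finally, since the comment above the transform registry notes that the @res suffix is ignored for modality transforms, a lookup along these lines (a hypothetical helper, not code from this PR) would map a domain name such as video_tok_rgb@224 to its registered transform:

```python
# Hypothetical lookup: strip the optional "@<resolution>" suffix before
# indexing into MODALITY_TRANSFORMS.
def get_modality_transform(domain, transforms):
    key = domain.split("@", 1)[0]        # "video_tok_rgb@224" -> "video_tok_rgb"
    return transforms[key]

# e.g. get_modality_transform("video_tok_rgb@224", MODALITY_TRANSFORMS)
# would return the VideoTokTransform() instance registered above.
```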