Skip to content

Commit

Permalink
Update configs to work with a filtered raw + add debug script with pdb
Browse files Browse the repository at this point in the history
  • Loading branch information
kdu4108 committed Aug 2, 2024
1 parent 4f15e2e commit a23098e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
4 changes: 2 additions & 2 deletions cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ train:
# Data can either be local or on cloud storage (e.g. S3), see data docs for more info
# Use braceexpand notation to indicate shard range (e.g. shard-{0000..9999}.tar)
# Use brackets to indicate multiple modalities (e.g. [modality1,modality2,modality3])
data_path: '/store/swissai/a08/data/4m/splits2/train/[video_rgb,video_rgb_tok]/00000{00000..00100}.tar' # TODO: need to reformat the data correctly here.
data_path: '/store/swissai/a08/data/4m/cleaned/train/[video_rgb,video_tok_rgb]/0000000000.tar' # TODO: need to reformat the data correctly here.
use_wds: True # Use webdataset
wds_n_repeats: 4 # Number of repeats for webdataset loader to improve efficiency
wds_shuffle_buffer_tar: 1_000 # Webdatasets shuffle buffer after loading tar files
Expand All @@ -41,4 +41,4 @@ train:
# val:
# datasets:
# my_video_dataset:
# data_path: '/store/swissai/a08/data/4m/val/[video_rgb,video_rgb_tok]/00000{00175..00199}.tar'
# data_path: '/store/swissai/a08/data/4m/val/[video_rgb,video_tok_rgb]/00000{00175..00199}.tar'
5 changes: 3 additions & 2 deletions fourm/data/unified_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def build_fm_transfer_dataset(
def _keyless_map(data, f, handler=reraise_exception):
"""Map samples without adding __key__."""
for sample in data:
import pdb; pdb.set_trace()
try:
result = f(sample)
except Exception as exn:
Expand Down Expand Up @@ -390,7 +391,7 @@ def build_wds_fm_pretraining_dataloader(

if batch_size is not None:
# Perform multi-threaded dataloading
return wds.WebLoader(datapipe, num_workers=num_workers, batch_size=None)
return wds.WebLoader(datapipe, num_workers=0, batch_size=None)
else:
return datapipe

Expand Down Expand Up @@ -552,6 +553,6 @@ def build_mixture_dataloader(data_iters, weights, modality_info, batch_size, num
wds.batched(batch_size, collation_fn=default_collate, partial=False),
).with_epoch(epoch_size // (num_gpus * num_workers * batch_size)) # Pre-define iterator length

mixture_loader = wds.WebLoader(mixture_pipe, num_workers=num_workers, batch_size=None)
mixture_loader = wds.WebLoader(mixture_pipe, num_workers=0, batch_size=None)

return mixture_loader
5 changes: 3 additions & 2 deletions run_training_4m.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# python run_training_4m.py --data_config cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml
import argparse
# FOR DEBUGGING with pdb:
# 1. Make sure wds.WebLoader is created with num_workers=0 in both a) build_wds_fm_pretraining_dataloader and b) build_mixture_dataloader, so that breakpoints are hit in the main process.
# 2. Then run: python -m torch.distributed.launch --nproc_per_node 1 --use-env run_training_4m.py --config /store/swissai/a08/kdu/ml-4m/cfgs/default/4m/models/video/4m-b_mod3.yaml
import argparse
import datetime
import json
import math
Expand Down

0 comments on commit a23098e

Please sign in to comment.