Official PyTorch implementation of the paper: "Multi-Source Augmentation and Composite Prompts for Visual Recognition with Missing Modality"
This paper introduces an innovative multi-source augmentation and composite prompts method to alleviate the missing-modality problem. We use retrieval-based and generation-based methods to recover the missing modality data and design a gate unit that lets the model automatically select the better source of augmentation data. We then design a composite prompts method to fine-tune the model, guiding it to better handle various input situations and to discover connections between them.
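As an illustration of the gate-unit idea, the sketch below fuses a retrieval-based feature and a generation-based feature with a learned sigmoid gate. This is a hypothetical minimal version for intuition only, not the paper's exact architecture; the module name and layer sizes are assumptions.

```python
import torch
import torch.nn as nn

class AugmentationGate(nn.Module):
    """Illustrative gate unit: learns a scalar weight in (0, 1) that
    decides how much to trust the retrieval-based feature versus the
    generation-based feature for the recovered (missing) modality."""

    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(2 * dim, dim),
            nn.ReLU(),
            nn.Linear(dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, retrieved: torch.Tensor, generated: torch.Tensor) -> torch.Tensor:
        # g close to 1 -> prefer the retrieved feature; close to 0 -> prefer the generated one.
        g = self.gate(torch.cat([retrieved, generated], dim=-1))
        return g * retrieved + (1 - g) * generated

gate = AugmentationGate(dim=768)
fused = gate(torch.randn(4, 768), torch.randn(4, 768))
```

In practice the fused feature would replace the missing modality's embedding before it enters the multimodal transformer.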
- Python 3.7.13
- PyTorch 1.10.0
- CUDA 11.3
```
pip install -r requirements.txt
```
We use two vision-and-language datasets: MM-IMDb and UPMC Food-101. Please download the datasets yourself.
We use the CLIP model to obtain the retrieval augmentation data and Unidiffuser to obtain the generation augmentation data. The example scripts are located in `data_augmentation/retrieval/retrieval.py` and `data_augmentation/generation/unidiffuser/sample.py`. Both scripts use the MM-IMDb dataset as an example. If you want to run them to produce augmentation data, make sure you have downloaded the dataset and update the data paths in the code.
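The core of the retrieval step is a nearest-neighbor search in a shared embedding space. The sketch below shows this under the assumption that query and gallery embeddings have already been extracted with a vision-language encoder such as CLIP; the function name and the toy data are hypothetical, not taken from `retrieval.py`.

```python
import torch
import torch.nn.functional as F

def retrieve_nearest(query_emb: torch.Tensor, gallery_embs: torch.Tensor, k: int = 1):
    """Return indices of the k gallery items most similar to each query,
    ranked by cosine similarity of L2-normalized embeddings."""
    q = F.normalize(query_emb, dim=-1)
    g = F.normalize(gallery_embs, dim=-1)
    sims = q @ g.T                       # (num_queries, num_gallery)
    return sims.topk(k, dim=-1).indices

# Toy example: gallery item 2 is identical to the query,
# so it is retrieved as the nearest neighbor.
gallery = torch.eye(4)
query = gallery[2:3]
idx = retrieve_nearest(query, gallery, k=1)  # -> tensor([[2]])
```

With real data, the gallery would hold embeddings of the available modality's candidates, and the retrieved item stands in for the missing one.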
Please see `DATA.md` for how to organize the datasets.
We use `pyarrow` to serialize the datasets; the conversion scripts are located in `vilt/utils/write_*.py`. Run the following script to create the pyarrow binary file:
```
python make_arrow.py --dataset [DATASET] --root [YOUR_DATASET_ROOT]
```
```
python run.py with data_root=<ARROW_ROOT> \
    num_gpus=<NUM_GPUS> \
    num_nodes=<NUM_NODES> \
    per_gpu_batchsize=<BS_FITS_YOUR_GPU> \
    <task_finetune_mmimdb or task_finetune_food101> \
    load_path=<MODEL_PATH> \
    exp_name=<EXP_NAME> \
    prompt_type=<PROMPT_TYPE> \
    test_ratio=<TEST_RATIO> \
    test_type=<TEST_TYPE> \
    add_type=both \
    test_only=True
```
- Download the pre-trained ViLT model weights from here.
- Start training:
```
python run.py with data_root=<ARROW_ROOT> \
    num_gpus=<NUM_GPUS> \
    num_nodes=<NUM_NODES> \
    per_gpu_batchsize=<BS_FITS_YOUR_GPU> \
    <task_finetune_mmimdb or task_finetune_food101> \
    load_path=<PRETRAINED_MODEL_PATH> \
    exp_name=<EXP_NAME>
```
This code is based on ViLT and Missing-aware-prompts. The data augmentation code is adapted from CLIP and Unidiffuser.