Official PyTorch implementation of "Multimodal Across Domains Gaze Target Detection" (ICMI 2022).
To run this repo, create a new conda environment and configure all environment variables using the provided templates:
conda env create -f environment.yml
cp .env.example .env
nano .env
Due to the complexity of the network, use a recent NVIDIA GPU with at least 6 GB of memory and CUDA 11.3+ installed. We also suggest running everything on a Linux-based OS, preferably Ubuntu 20.04.
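To quickly confirm that PyTorch can see a suitable GPU and CUDA build before training, you can run a small sanity check such as the one below (this helper is not part of the repository; it only uses standard PyTorch APIs):

# sanity_check.py -- verify that PyTorch sees a CUDA GPU and report its memory (illustrative helper, not part of this repo)
import torch

assert torch.cuda.is_available(), "No CUDA device visible to PyTorch"
props = torch.cuda.get_device_properties(0)
print(f"GPU: {props.name}, memory: {props.total_memory / 1024**3:.1f} GiB")
print(f"PyTorch CUDA build: {torch.version.cuda}")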
This network was trained and evaluated on three popular datasets: GazeFollow (extended), VideoAttentionTarget, and GOO (real). We further extended each sample with depth data. You can extract the depth maps using the provided scripts:
# GazeFollow
python scripts/gazefollow_get_depth.py --dataset_dir /path/to/gazefollow_extended

# VideoAttentionTarget
python scripts/videoattentiontarget_get_depth.py --dataset_dir /path/to/videoattentiontarget

# GOO (real)
python scripts/goo_get_depth.py --dataset_dir /path/to/goo/real
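The provided scripts take care of depth extraction end to end. Purely as an illustration of the general idea, here is a minimal monocular depth-estimation sketch using MiDaS via torch.hub; this is an assumption for illustration only, and the repository's scripts may use a different estimator, resolution, and output format:

# depth_sketch.py -- illustrative only; the repo's scripts/*_get_depth.py are the supported way to generate depth maps
import cv2
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load a MiDaS model and its matching input transform from torch.hub
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large").to(device).eval()
transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = transforms.dpt_transform

img = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
with torch.no_grad():
    prediction = midas(transform(img).to(device))
    # Resize the predicted depth map back to the original image resolution
    depth = torch.nn.functional.interpolate(
        prediction.unsqueeze(1), size=img.shape[:2], mode="bicubic", align_corners=False
    ).squeeze().cpu().numpy()

# Store a normalized 8-bit depth image next to the input frame
cv2.imwrite("example_depth.png", cv2.normalize(depth, None, 0, 255, cv2.NORM_MINMAX).astype("uint8"))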
Before training, download the pretrained weights here and pass their path via `--init_weights`.
The script can train and evaluate across different datasets. To train and evaluate on the same dataset, set `source_dataset` and `target_dataset` to the same value. To run evaluation only, set the `eval_weights` argument. We also release our trained checkpoints for GazeFollow and VideoAttentionTarget. Example commands are shown below, after the options list.
python main.py [-h] [--tag TAG] [--device {cpu,cuda,mps}] [--input_size INPUT_SIZE] [--output_size OUTPUT_SIZE] [--batch_size BATCH_SIZE]
[--source_dataset_dir SOURCE_DATASET_DIR] [--source_dataset {gazefollow,videoattentiontarget,goo}] [--target_dataset_dir TARGET_DATASET_DIR]
[--target_dataset {gazefollow,videoattentiontarget,goo}] [--num_workers NUM_WORKERS] [--init_weights INIT_WEIGHTS] [--eval_weights EVAL_WEIGHTS] [--lr LR]
[--epochs EPOCHS] [--evaluate_every EVALUATE_EVERY] [--save_every SAVE_EVERY] [--print_every PRINT_EVERY] [--no_resume] [--output_dir OUTPUT_DIR] [--amp AMP]
[--freeze_scene] [--freeze_face] [--freeze_depth] [--head_da] [--rgb_depth_da] [--task_loss_amp_factor TASK_LOSS_AMP_FACTOR]
[--rgb_depth_source_loss_amp_factor RGB_DEPTH_SOURCE_LOSS_AMP_FACTOR] [--rgb_depth_target_loss_amp_factor RGB_DEPTH_TARGET_LOSS_AMP_FACTOR]
[--adv_loss_amp_factor ADV_LOSS_AMP_FACTOR] [--no_wandb] [--no_save]
optional arguments:
-h, --help show this help message and exit
--tag TAG Description of this run
--device {cpu,cuda,mps}
--input_size INPUT_SIZE
input size
--output_size OUTPUT_SIZE
output size
--batch_size BATCH_SIZE
batch size
--source_dataset_dir SOURCE_DATASET_DIR
directory where the source dataset is located
--source_dataset {gazefollow,videoattentiontarget,goo}
--target_dataset_dir TARGET_DATASET_DIR
directory where the target dataset is located
--target_dataset {gazefollow,videoattentiontarget,goo}
--num_workers NUM_WORKERS
--init_weights INIT_WEIGHTS
initial weights
--eval_weights EVAL_WEIGHTS
If set, performs evaluation only
--lr LR learning rate
--epochs EPOCHS number of epochs
--evaluate_every EVALUATE_EVERY
evaluate every N epochs
--save_every SAVE_EVERY
save model every N epochs
--print_every PRINT_EVERY
print training stats every N batches
--no_resume Do not resume from a previously stopped run (by default, training resumes if a checkpoint exists)
--output_dir OUTPUT_DIR
Path to output folder
--amp AMP AMP optimization level
--freeze_scene Freeze the scene backbone
--freeze_face Freeze the head backbone
--freeze_depth Freeze the depth backbone
--head_da Enable domain adaptation (DA) on the head backbone
--rgb_depth_da Enable domain adaptation (DA) on the RGB/depth backbone
--task_loss_amp_factor TASK_LOSS_AMP_FACTOR
amplification factor for the task loss
--rgb_depth_source_loss_amp_factor RGB_DEPTH_SOURCE_LOSS_AMP_FACTOR
amplification factor for the RGB/depth loss on the source domain
--rgb_depth_target_loss_amp_factor RGB_DEPTH_TARGET_LOSS_AMP_FACTOR
amplification factor for the RGB/depth loss on the target domain
--adv_loss_amp_factor ADV_LOSS_AMP_FACTOR
amplification factor for the adversarial loss
--no_wandb Disables wandb
--no_save Do not save a checkpoint every {save_every} epochs; store only the last checkpoint to allow resuming
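For example, to train and evaluate on GazeFollow with source and target set to the same dataset, a command along the following lines can be used (all paths and weight filenames are placeholders to be replaced with your own):

# Train and evaluate on GazeFollow (source == target); paths and weight files are placeholders
python main.py --source_dataset gazefollow --source_dataset_dir /path/to/gazefollow_extended \
               --target_dataset gazefollow --target_dataset_dir /path/to/gazefollow_extended \
               --init_weights /path/to/pretrained_weights.pt

# Evaluation only, using a released checkpoint (placeholder filename)
python main.py --source_dataset gazefollow --source_dataset_dir /path/to/gazefollow_extended \
               --target_dataset gazefollow --target_dataset_dir /path/to/gazefollow_extended \
               --eval_weights /path/to/checkpoint.pt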
If you use our code, please cite:
@inproceedings{tonini2022multimodal,
author = {Tonini, Francesco and Beyan, Cigdem and Ricci, Elisa},
title = {Multimodal Across Domains Gaze Target Detection},
year = {2022},
isbn = {9781450393904},
publisher = {Association for Computing Machinery},
address = {New York, NY, USA},
url = {https://doi.org/10.1145/3536221.3556624},
doi = {10.1145/3536221.3556624},
booktitle = {Proceedings of the 2022 International Conference on Multimodal Interaction},
pages = {420–431},
series = {ICMI '22}
}