[2024/08/10] arXiv paper released.
Environment
All models were trained and tested on Ubuntu 18.04 with Python 3.8.17 and PyTorch 2.0.1 with CUDA 11.8.
You can start by creating a conda virtual environment:
conda create -n dada python=3.8 -y
source activate dada
pip install -r requirements.txt
Dataset
Preprocessed ETH and UCY datasets are included in this repository under ./datasets/.
Among these datasets, train_origin, val and test are obtained directly from the ETH and UCY datasets, and train is obtained after DLA processing.
We have provided an example, A2B, in ./datasets/A2B/, which demonstrates how to set up the dataset for a particular cross-domain task. If you want to construct another S2T dataset, please follow the steps below (using the B2C dataset as an example):
- create a folder named B2C under ./datasets/;
- create four folders named train_origin, train, val and test under ./datasets/B2C/;
- put the B-domain (HOTEL) training set into the newly created train_origin and train folders; put the C-domain (UNIV) validation set into the val folder; put the C-domain testing set into the test folder;
- train the corresponding DLA model to automatically generate the aligned source data in the train folder.
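The folder-creation steps above can be sketched in Python as follows. This is a minimal sketch and not part of the repository: the helper name scaffold_task and its datasets_root argument are hypothetical. Copying the actual HOTEL and UNIV files is left out, because where they live depends on your checkout.

```python
from pathlib import Path

# Sub-folders that each S2T task directory contains, per the steps above.
SPLITS = ("train_origin", "train", "val", "test")


def scaffold_task(datasets_root, task):
    """Create <datasets_root>/<task>/ with the four split folders.

    Hypothetical helper for illustration; it only creates empty folders.
    """
    task_dir = Path(datasets_root) / task
    created = []
    for split in SPLITS:
        split_dir = task_dir / split
        # exist_ok=True makes the helper safe to run more than once.
        split_dir.mkdir(parents=True, exist_ok=True)
        created.append(split_dir)
    return created
```

For the B2C example, calling scaffold_task("./datasets", "B2C") from the repository root would create the four empty folders. The data still has to be copied in, and the DLA model trained, as described in the list.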
Baseline Models
This repository supports three baseline models: Social-GAN, Trajectron++ and TUTR.
Their DADA-modified source code is in ./models/.
To train and evaluate our DADA model on the A2B task in one step, we provide the bash scripts train_DADA.sh and test_DADA.sh for simplified execution.
bash ./train_DADA.sh -b <baseline_model> # quickly train
bash ./test_DADA.sh -b <baseline_model> # quickly evaluate
where <baseline_model> can be sgan, trajectron++ or tutr.
For example:
bash ./train_DADA.sh -b sgan # quickly train
bash ./test_DADA.sh -b sgan # quickly evaluate
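To run the quick train and evaluate scripts for all three baselines in sequence, the command lines can be generated programmatically. The sketch below only builds the argument lists. The helper names dada_commands and all_commands are hypothetical, and only the script names and the -b flag shown above are used.

```python
# Baseline identifiers accepted by the -b flag, as listed above.
BASELINES = ("sgan", "trajectron++", "tutr")


def dada_commands(baseline):
    """Return the train and test command lines for one baseline."""
    if baseline not in BASELINES:
        raise ValueError(
            f"unknown baseline {baseline!r}; expected one of {BASELINES}"
        )
    return [
        ["bash", "./train_DADA.sh", "-b", baseline],
        ["bash", "./test_DADA.sh", "-b", baseline],
    ]


def all_commands():
    """Flatten the train/test commands for every supported baseline."""
    return [cmd for name in BASELINES for cmd in dada_commands(name)]
```

Each returned list can be passed to subprocess.run(cmd, check=True) from the repository root. Because the arguments are passed as a list rather than a shell string, the ++ in trajectron++ needs no quoting.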
The DLA network can be trained with:
cd ./DLA/
python train_DLA.py --subset <task_S2T>
For example:
python train_DLA.py --subset A2B
After training finishes, the aligned source data will be automatically generated in ./datasets/<task_S2T>/train/.
Since the repository supports three baseline models, we take Social-GAN as the example here.
Training for Baseline
The baseline model is directly trained without DLA data:
cd ./models/sgan/scripts/
python train.py --dataset_name <task_S2T>
Training for DLA
The DLA model is trained with DLA data, so you only need to change the train_set path to f'../../../datasets/{args.dataset_name}/train' and the checkpoint_save path to '../checkpoint/checkpoint_DLA'.
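The path change described above can be made explicit with a small helper that picks the paths from a flag, instead of editing them by hand. This is a sketch under stated assumptions: resolve_paths and use_dla are hypothetical names. The baseline is assumed to read train_origin, which the Dataset section describes as the unprocessed data. The baseline checkpoint path '../checkpoint/checkpoint_baseline' is a placeholder that is not taken from the repository.

```python
def resolve_paths(dataset_name, use_dla):
    """Return (train_set, checkpoint_save) for baseline or DLA training.

    Hypothetical helper; only the DLA paths come from the README.
    """
    if use_dla:
        # Paths stated in the README for DLA training.
        train_set = f"../../../datasets/{dataset_name}/train"
        checkpoint_save = "../checkpoint/checkpoint_DLA"
    else:
        # ASSUMPTION: the baseline reads the unaligned train_origin split,
        # and this checkpoint name is a placeholder.
        train_set = f"../../../datasets/{dataset_name}/train_origin"
        checkpoint_save = "../checkpoint/checkpoint_baseline"
    return train_set, checkpoint_save
```

The relative paths resolve correctly only when the script is run from ./models/sgan/scripts/, as in the commands above.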
Training for DADA
Relative to the DLA model, the DADA model additionally embeds a discriminator during the training phase:
cd ./models/sgan/scripts/
python train_DADA.py --dataset_name <task_S2T>
You can find the code of the discriminator structure and its training procedure.
As with training, we take Social-GAN as the example here, since the repository supports three baseline models.
Pretrained Models
We have included pre-trained models in the ./models/sgan/checkpoint/ folder, which can be used directly for evaluation.
You can view the DADA evaluation result for a task such as A2B by running:
cd ./models/sgan/scripts/
python evaluate_model.py --dataset_name <task_S2T>
To view the baseline and DLA evaluation results, you only need to modify the checkpoint_load path.