forked from jackyzengl/DADA

DADA

Dual-Alignment Domain Adaptation for Pedestrian Trajectory Prediction


News

  • [2024/08/10] arXiv paper released.

Setup

Environment
All models were trained and tested on Ubuntu 18.04 with Python 3.8.17 and PyTorch 2.0.1 with CUDA 11.8. You can start by creating a conda virtual environment:

conda create -n dada python=3.8 -y
source activate dada
pip install -r requirements.txt

Dataset
Preprocessed ETH and UCY datasets are included in this repository under ./datasets/.
For each task, the train_origin, val and test folders come directly from the ETH and UCY datasets, and the train folder holds the data produced by DLA processing.
We provide an example task, A2B, in ./datasets/A2B/, which shows how to set up the dataset for a particular cross-domain task. To construct another S2T (source-to-target) dataset, follow the steps below (using B2C as an example):

  • create a folder named B2C under ./datasets/;
  • create four folders named train_origin, train, val and test under ./datasets/B2C/;
  • put the B-domain (HOTEL) training set into the newly created train_origin and train folders; put the C-domain (UNIV) validation set into the val folder; put the C-domain testing set into the test folder;
  • train the corresponding DLA model to automatically generate the aligned source data in train folder.
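
The folder layout described in the steps above can be sketched in Python. This is a minimal sketch and not part of the repository: the helper name make_task_layout and its arguments are hypothetical, and it only creates the empty folders. Copying the source-domain and target-domain files into them is still a manual step.

```python
from pathlib import Path

# The four split folders named in the steps above.
SPLITS = ("train_origin", "train", "val", "test")


def make_task_layout(datasets_root, task):
    """Create <datasets_root>/<task>/ with one empty folder per split.

    Hypothetical helper for illustration; `task` is a name such as "B2C".
    """
    task_dir = Path(datasets_root) / task
    for split in SPLITS:
        # exist_ok=True makes the call safe to repeat on an existing layout.
        (task_dir / split).mkdir(parents=True, exist_ok=True)
    return task_dir
```

Calling make_task_layout("./datasets", "B2C") from the repository root would create the B2C folders next to the provided A2B example.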

Baseline Models
This repository supports three baseline models: Social-GAN, Trajectron++ and TUTR. Their DADA-modified source code is in ./models/.

Quick Start

To train and evaluate our DADA model on the A2B task in one go, we provide two bash scripts, train_DADA.sh and test_DADA.sh:

bash ./train_DADA.sh -b <baseline_model>  # quickly train
bash ./test_DADA.sh -b <baseline_model>  # quickly evaluate

where <baseline_model> can be sgan, trajectron++ or tutr.
For example:

bash ./train_DADA.sh -b sgan  # quickly train
bash ./test_DADA.sh -b sgan  # quickly evaluate

Detailed Training

Training for DLA

The DLA network can be trained with:

cd ./DLA/
python train_DLA.py  --subset <task_S2T>

For example:

python train_DLA.py  --subset A2B

After training finishes, the aligned source data is generated automatically in ./datasets/<task_S2T>/train/.
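
Before training a prediction model, it can help to confirm that the train folder of the task was actually populated. The check below is a minimal sketch under the folder layout described in the Dataset section. The function name list_aligned_files is hypothetical, and nothing is assumed about the names or format of the generated files.

```python
from pathlib import Path


def list_aligned_files(datasets_root, task):
    """Return the sorted file names found in <datasets_root>/<task>/train.

    Hypothetical helper for illustration; raises if the folder is missing.
    """
    train_dir = Path(datasets_root) / task / "train"
    if not train_dir.is_dir():
        raise FileNotFoundError(f"missing folder: {train_dir}")
    return sorted(p.name for p in train_dir.iterdir() if p.is_file())
```

An empty list returned for a task such as A2B would suggest that DLA training has not yet written its output there.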

Training for Prediction Models

Since our repository supports three baseline models, we take Social-GAN as an example here.

Training for Baseline
The baseline model is directly trained without DLA data:

cd ./models/sgan/scripts/
python train.py --dataset_name <task_S2T>

Training for DLA
The DLA model is trained with the DLA-aligned data. Use the same train.py script as for the baseline, but change the train_set path to f'../../../datasets/{args.dataset_name}/train' and the checkpoint_save path to '../checkpoint/checkpoint_DLA'.
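
The two path changes can be sketched as a small function. This is an illustration only: the names train_set and checkpoint_save come from the text above, but where they are defined inside train.py is not shown here, and the function dla_training_paths is hypothetical.

```python
def dla_training_paths(dataset_name):
    """Build the two DLA training paths for a task such as "A2B".

    Both paths are relative to ./models/sgan/scripts/, where train.py is run.
    """
    # DLA-aligned source data for the chosen task.
    train_set = f"../../../datasets/{dataset_name}/train"
    # Separate checkpoint location used for the DLA model.
    checkpoint_save = "../checkpoint/checkpoint_DLA"
    return train_set, checkpoint_save
```

For the A2B task, the function returns '../../../datasets/A2B/train' as the training set path.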

Training for DADA
The DADA model builds on the DLA model by additionally embedding a discriminator during the training phase:

cd ./models/sgan/scripts/
python train_DADA.py --dataset_name <task_S2T>

You can find the code of the discriminator structure and its training procedure.

Detailed Evaluation

Since our repository supports three baseline models, we take Social-GAN as an example here.

Pretrained Models
We have included pre-trained models in the ./models/sgan/checkpoint/ folder, which can be used directly for evaluation.

To view the DADA evaluation results for a task (for example, A2B), run:

cd ./models/sgan/scripts/
python evaluate_model.py --dataset_name <task_S2T>

To view the baseline and DLA evaluation results, you only need to modify the checkpoint_load path.

Experimental results

ADE results


FDE results


Visualization

