Commit

initial

Co-authored-by: Yang Gao <[email protected]>
SaeedSaadatnejad and yanggao2000 committed Dec 27, 2023
1 parent 04e2281 commit 5ae0f33
Showing 22 changed files with 2,437 additions and 0 deletions.
58 changes: 58 additions & 0 deletions .gitignore
@@ -0,0 +1,58 @@
# Ignore Python virtual environment
venv/

# Ignore compiled Python files
*.pyc

# Ignore cache and temporary files
__pycache__/
*.pyo
*.swp
*.swo

# Ignore IDE-specific files
.vscode/
.idea/

# Ignore environment-specific files
.env
.env.local
.env.*.local

# Ignore log files
*.log

# Ignore package lock files
pip-lock.txt
poetry.lock

# Ignore generated documentation
docs/_build/

# Ignore test coverage reports
htmlcov/

# Ignore compiled binaries
*.exe
*.dll
*.so
*.dylib

# Ignore database files
*.db

# Ignore generated files
*.pyd

# Ignore cache files
*.cache

# Ignore system-specific files
.DS_Store
Thumbs.db
*.ndjson
experiments/*
data/*
61 changes: 61 additions & 0 deletions README.md
@@ -0,0 +1,61 @@
<div align="center">
<h1> Social-Transmotion:<br> Promptable Human Trajectory Prediction </h1>
<h3>Saeed Saadatnejad*, Yang Gao*, Kaouther Messaoud, Alexandre Alahi
</h3>


<img src="docs/social-transmotion.png" width="600">
</div>

<div align="center"> <h3> Abstract </h3> </div>
<div align="justify">

Accurate human trajectory prediction is crucial for applications such as autonomous vehicles, robotics, and surveillance systems. Yet, existing models often fail to fully leverage the non-verbal social cues humans subconsciously communicate when navigating space.
To address this, we introduce Social-Transmotion, a generic model that exploits the power of transformers to handle diverse and numerous visual cues, capturing the multi-modal nature of human behavior. We translate the idea of a prompt from Natural Language Processing (NLP) to the task of human trajectory prediction, where a prompt can be a sequence of x-y coordinates on the ground, bounding boxes, or body poses. This, in turn, augments trajectory data, leading to enhanced human trajectory prediction.
Our model exhibits flexibility and adaptability by capturing spatiotemporal interactions between pedestrians based on the available visual cues, whether they are poses, bounding boxes, or a combination thereof.
Through a masking technique, our model remains effective even when certain visual cues are unavailable, although performance is further boosted when comprehensive visual data is present.
<br/>


# Getting Started

Install the requirements using `pip`:
```
pip install -r requirements.txt
```
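
If you prefer an isolated setup, you can create a Python virtual environment first (the `.gitignore` already excludes `venv/`); a typical sequence would be:
```
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```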

We have conveniently added the preprocessed data to the release section of the repository.
Place the data subdirectory of JTA under `data/jta_all_visual_cues` and the data subdirectory of JRDB under `data/jrdb_2dbox` in the repository root.
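
After this step, the repository should contain roughly the following layout (directory names as used by the commands below; adjust if you only use one of the datasets):
```
data/
├── jta_all_visual_cues/
└── jrdb_2dbox/
```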

# Training and Testing

## JTA dataset
You can train the Social-Transmotion model on this dataset using the following command:
```
python train_jta.py --cfg configs/jta_all_visual_cues.yaml --exp_name jta
```


To evaluate the trained model, use the following command:
```
python evaluate_jta.py --ckpt ./experiments/jta/checkpoints/checkpoint.pth.tar --metric ade_fde --modality traj+all
```
Please note that the evaluation modality can be any of `[traj, traj+2dbox, traj+3dpose, traj+2dpose, traj+3dpose+3dbox, traj+all]`.
For ease of use, we have also provided the trained model in the release section of this repo. To use it, pass the path of the saved checkpoint via `--ckpt`.
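
For example, to evaluate using only trajectories and 2D bounding boxes as input cues, the command becomes:
```
python evaluate_jta.py --ckpt ./experiments/jta/checkpoints/checkpoint.pth.tar --metric ade_fde --modality traj+2dbox
```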

## JRDB dataset
You can train the Social-Transmotion model on this dataset using the following command:
```
python train_jrdb.py --cfg configs/jrdb_2dbox.yaml --exp_name jrdb
```

To evaluate the trained model, use the following command:
```
python evaluate_jrdb.py --ckpt ./experiments/jrdb/checkpoints/checkpoint.pth.tar --metric ade_fde --modality traj+2dbox
```
Please note that the evaluation modality can be any of `[traj, traj+2dbox]`.
For ease of use, we have also provided the trained model in the release section of this repo. To use it, pass the path of the saved checkpoint via `--ckpt`.
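
For instance, evaluating with trajectories only (no bounding-box cues) would look like:
```
python evaluate_jrdb.py --ckpt ./experiments/jrdb/checkpoints/checkpoint.pth.tar --metric ade_fde --modality traj
```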

# Work in Progress

This repository is a work in progress and will continue to be updated and improved over the coming months.
Empty file added __init__.py
Empty file.
29 changes: 29 additions & 0 deletions configs/jrdb_2dbox.yaml
@@ -0,0 +1,29 @@
SEED: 0
TRAIN:
  batch_size: 16
  epochs: 100
  num_workers: 0
  input_track_size: 9
  output_track_size: 12
  lr: 0.0001
  lr_decay: 1
  lr_drop: true
  aux_weight: 0.2
  val_frequency: 5
  optimizer: "adam"
  max_grad_norm: 1.0
DATA:
  train_datasets:
    - jrdb_2dbox
MODEL:
  seq_len: 30
  token_num: 2
  num_layers_local: 6
  num_layers_global: 3
  num_heads: 4
  dim_hidden: 128
  dim_feedforward: 1024
  type: "transmotion"
  eval_single: false
  checkpoint: ""
  output_scale: 1
29 changes: 29 additions & 0 deletions configs/jta_all_visual_cues.yaml
@@ -0,0 +1,29 @@
SEED: 0
TRAIN:
  batch_size: 4
  epochs: 50
  num_workers: 0
  input_track_size: 9
  output_track_size: 12
  lr: 0.0001
  lr_decay: 1
  lr_drop: true
  aux_weight: 0.2
  val_frequency: 5
  optimizer: "adam"
  max_grad_norm: 1.0
DATA:
  train_datasets:
    - jta_all_visual_cues
MODEL:
  seq_len: 435 # sequence length for the local-former: 1*21 + (token_num-1)*9, i.e. 21 + 46*9 = 435 here; 219 for 2d/3d pose, 30 for 2d/3d bb, 21 for baseline, 228 for 3dbox+3dpose
  token_num: 47 # number of tokens for the local-former; 23 for 2d/3d pose, 2 for 2d/3d bb, 1 for baseline
  num_layers_local: 6
  num_layers_global: 3
  num_heads: 4
  dim_hidden: 128
  dim_feedforward: 1024
  type: "transmotion"
  eval_single: false
  checkpoint: "" # e.g. checkpoint.pth.tar
  output_scale: 1
194 changes: 194 additions & 0 deletions dataset_jrdb.py
@@ -0,0 +1,194 @@
import torch
from torch.nn.utils.rnn import pad_sequence

from utils.data import load_data_jta_all_visual_cues, load_data_jrdb_2dbox
from torchvision import transforms

def collate_batch(batch):
    """Pad scenes with different numbers of people into a single batch."""
    joints_list = []
    masks_list = []
    num_people_list = []
    for joints, masks in batch:
        joints_list.append(joints)
        masks_list.append(masks)
        num_people_list.append(torch.zeros(joints.shape[0]))

    # Pad along the people dimension; padded (fake) people are flagged True in padding_mask.
    joints = pad_sequence(joints_list, batch_first=True)
    masks = pad_sequence(masks_list, batch_first=True)
    padding_mask = pad_sequence(num_people_list, batch_first=True, padding_value=1).bool()

    return joints, masks, padding_mask


def batch_process_coords(coords, masks, padding_mask, config, modality_selection='traj+2dbox', training=False, multiperson=True):
    joints = coords.to(config["DEVICE"])
    masks = masks.to(config["DEVICE"])
    in_F = config["TRAIN"]["input_track_size"]

    in_joints_pelvis = joints[:,:, (in_F-1):in_F, 0:1, :].clone()
    in_joints_pelvis_last = joints[:,:, (in_F-2):(in_F-1), 0:1, :].clone()

    # Center all trajectories on the primary person's last observed position,
    # and express the bounding-box cues relative to each person's last observed frame.
    joints[:,:,:,0] = joints[:,:,:,0] - joints[:,0:1, (in_F-1):in_F, 0]
    joints[:,:,:,1:] = (joints[:,:,:,1:] - joints[:,:,(in_F-1):in_F,1:])*0.25  # rescale for BB

    B, N, F, J, K = joints.shape
    if not training:
        # At evaluation time, zero out the cues excluded by the selected modality.
        if modality_selection == 'traj':
            joints[:,:,:,1:] = 0
        elif modality_selection == 'traj+2dbox':
            pass
        else:
            raise ValueError(f"Unsupported modality selection: {modality_selection}")
    else:
        # Augment JRDB trajectories with a random rotation about the vertical axis.
        joints[:,:,:,0,:3] = getRandomRotatePoseTransform(config)(joints[:,:,:,0,:3])

    joints = joints.transpose(1, 2).reshape(B, F, N*J, K)
    in_joints_pelvis = in_joints_pelvis.reshape(B, 1, N, K)
    in_joints_pelvis_last = in_joints_pelvis_last.reshape(B, 1, N, K)
    masks = masks.transpose(1, 2).reshape(B, F, N*J)

    # Split each track into observed (input) and future (output) frames.
    in_F, out_F = config["TRAIN"]["input_track_size"], config["TRAIN"]["output_track_size"]
    in_joints = joints[:,:in_F].float()
    out_joints = joints[:,in_F:in_F+out_F].float()
    in_masks = masks[:,:in_F].float()
    out_masks = masks[:,in_F:in_F+out_F].float()

    return in_joints, in_masks, out_joints, out_masks, padding_mask.float()

def getRandomRotatePoseTransform(config):
    """
    Performs a random rotation about the origin (0, 0, 0).
    """
    def do_rotate(pose_seq):
        B, F, J, K = pose_seq.shape

        # One random angle per batch element, in radians.
        angles = torch.deg2rad(torch.rand(B)*360)

        rotation_matrix = torch.zeros(B, 3, 3).to(pose_seq.device)

        # Rotate around the z axis (vertical axis).
        rotation_matrix[:,0,0] = torch.cos(angles)
        rotation_matrix[:,0,1] = -torch.sin(angles)
        rotation_matrix[:,1,0] = torch.sin(angles)
        rotation_matrix[:,1,1] = torch.cos(angles)
        rotation_matrix[:,2,2] = 1

        rot_pose = torch.bmm(pose_seq.reshape(B, -1, 3).float(), rotation_matrix)
        rot_pose = rot_pose.reshape(pose_seq.shape)
        return rot_pose

    return transforms.Lambda(lambda x: do_rotate(x))



class MultiPersonTrajPoseDataset(torch.utils.data.Dataset):

    def __init__(self, name, split="train", track_size=21, track_cutoff=9, segmented=True,
                 add_flips=False, frequency=1):
        self.name = name
        self.split = split
        self.track_size = track_size
        self.track_cutoff = track_cutoff
        self.frequency = frequency

        self.initialize()

    def load_data(self):
        raise NotImplementedError("Dataset load_data() method is not implemented.")

    def initialize(self):
        self.load_data()

        # Cut every scene into non-overlapping tracks of track_size frames, sampled at self.frequency.
        tracks = []
        for scene in self.datalist:
            for seg, j in enumerate(range(0, len(scene[0][0]) - self.track_size * self.frequency + 1, self.track_size)):
                people = []
                for person in scene:
                    start_idx = j
                    end_idx = start_idx + self.track_size * self.frequency
                    J_3D_real, J_3D_mask = person[0][start_idx:end_idx:self.frequency], person[1][start_idx:end_idx:self.frequency]
                    people.append((J_3D_real, J_3D_mask))
                tracks.append(people)
        self.datalist = tracks

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, idx):
        scene = self.datalist[idx]

        # Stack all people in the scene: (num_people, frames, keypoints, coords).
        J_3D_real = torch.stack([s[0] for s in scene])
        J_3D_mask = torch.stack([s[1] for s in scene])

        return J_3D_real, J_3D_mask


class JtaAllVisualCuesDataset(MultiPersonTrajPoseDataset):
    def __init__(self, **args):
        super(JtaAllVisualCuesDataset, self).__init__("jta_all_visual_cues", frequency=1, **args)

    def load_data(self):
        self.data = load_data_jta_all_visual_cues(split=self.split)
        self.datalist = []
        for scene in self.data:
            joints, mask = scene
            people = []
            for n in range(len(joints)):
                people.append((torch.from_numpy(joints[n]), torch.from_numpy(mask[n])))

            self.datalist.append(people)


class Jrdb2dboxDataset(MultiPersonTrajPoseDataset):
    def __init__(self, **args):
        super(Jrdb2dboxDataset, self).__init__("jrdb_2dbox", frequency=1, **args)

    def load_data(self):
        self.data = load_data_jrdb_2dbox(split=self.split)
        self.datalist = []
        for scene in self.data:
            joints, mask = scene
            people = []
            for n in range(len(joints)):
                people.append((torch.from_numpy(joints[n]), torch.from_numpy(mask[n])))

            self.datalist.append(people)


def create_dataset(dataset_name, logger, **args):
    logger.info("Loading dataset " + dataset_name)

    if dataset_name == 'jta_all_visual_cues':
        dataset = JtaAllVisualCuesDataset(**args)
    elif dataset_name == 'jrdb_2dbox':
        dataset = Jrdb2dboxDataset(**args)
    else:
        raise ValueError(f"Dataset with name '{dataset_name}' not found.")

    return dataset


def get_datasets(datasets_list, config, logger):
    in_F, out_F = config['TRAIN']['input_track_size'], config['TRAIN']['output_track_size']
    datasets = []
    for dataset_name in datasets_list:
        datasets.append(create_dataset(dataset_name, logger, split="train", track_size=(in_F+out_F), track_cutoff=in_F))
    return datasets
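

# Minimal usage sketch (illustrative only, not part of the original pipeline): build a
# DataLoader from the JRDB dataset defined above, pairing it with collate_batch. It assumes
# the preprocessed data is placed under data/jrdb_2dbox as described in the README;
# track_size/track_cutoff match the 9 observed + 12 predicted frames used in the configs.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = Jrdb2dboxDataset(split="train", track_size=21, track_cutoff=9)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_batch)
    joints, masks, padding_mask = next(iter(loader))
    print(joints.shape, masks.shape, padding_mask.shape)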






