-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
47271ca
commit 1fc76f5
Showing
189 changed files
with
1,203 additions
and
24,512 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,17 @@ | ||
# JRDB-Traj | ||
JRDB Data Preprocessing and Trajectory Prediction Baselines | ||
|
||
## Prerequisites | ||
Install requirements with `bash requirement.sh`. | ||
|
||
## Repository Overview | ||
The pipeline encompasses three key steps: | ||
The pipeline encompasses four key steps: | ||
|
||
1. `python traj_extractor.py`: This script preprocesses the JRDB dataset, extracting trajectories for further analysis. | ||
2. `bash traj_categorize.sh`: Utilizing the TrajNet++ benchmark, this script categorizes '.csv' files and generates '.ndjson' files for the next step. | ||
1. `bash dataload.sh`: This script preprocesses the JRDB dataset, extracting trajectories for further analysis. | ||
2. `bash preprocess.sh`: Utilizing the TrajNet++ benchmark, this script categorizes '.csv' files and generates '.ndjson' files for the next step. | ||
3. `bash train.sh`: This script trains baseline trajectory prediction models using the meticulously prepared data. | ||
4. `bash eval.sh`: This script will generate predictions in JRDB leaderboard format. | ||
|
||
|
||
## Work in Progress | ||
This repository is being updated so stay tuned! | ||
This repository is being updated so please stay tuned! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Extract raw JRDB trajectories. After extraction the data is located at the
# output path you set below.
jrdb_path="/data2/saeed-data/jrdb/train_dataset/labels/"
# NOTE: placeholder — point this at your local JRDB test tracking labels.
# (Fixed: the closing quote was missing, which broke parsing of the rest of
# the script.)
jrdb_test_path=".../test_trackings/"
out_path="OUT_tmp"
python train_traj_extractor.py --out_path "$out_path" --jrdb_path "$jrdb_path"
python test_traj_extractor.py --out_path "$out_path" --jrdb_path "$jrdb_test_path"
# The extractors also leave two scratch folders ('temp' and 'conf_temp');
# they are safe to remove.
rm -r "$out_path/temp" "$out_path/conf_temp"
# Move the extracted data to 'trajnetplusplusdataset/data/raw/'.
mv "$out_path" trajnetplusplusdataset/data/raw/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Run the TrajNet++ evaluator on the trained social-LSTM baseline checkpoint
# and write predictions (JRDB leaderboard format per the README).
cd jrdb_baselines
python -m trajnetbaselines.lstm.trajnet_evaluator --path jrdb_traj_with_nan --output OUTPUT_BLOCK/jrdb_traj_with_nan/lstm_social_baseline.pkl
Binary file not shown.
File renamed without changes.
File renamed without changes.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,10 +15,6 @@ | |
], | ||
license='MIT', | ||
description='Trajnet baselines.', | ||
long_description=open('README.rst').read(), | ||
author='Sven Kreiss', | ||
author_email='[email protected]', | ||
url='https://github.com/svenkreiss/trajnetbaselines', | ||
|
||
install_requires=[ | ||
'numpy', | ||
|
Binary file not shown.
2 changes: 1 addition & 1 deletion
2
train/trajnetbaselines/__init__.py → jrdb_baselines/trajnetbaselines/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
__version__ = '0.1.0' | ||
|
||
from . import augmentation | ||
from . import lstm | ||
from . import lstm |
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion
2
train/trajnetbaselines/lstm/__init__.py → ...selines/trajnetbaselines/lstm/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file added
BIN
+614 Bytes
jrdb_baselines/trajnetbaselines/lstm/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added
BIN
+617 Bytes
jrdb_baselines/trajnetbaselines/lstm/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file added
BIN
+2.34 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/data_load_utils.cpython-310.pyc
Binary file not shown.
Binary file added
BIN
+2.33 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/data_load_utils.cpython-37.pyc
Binary file not shown.
Binary file added
BIN
+11 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/gridbased_pooling.cpython-310.pyc
Binary file not shown.
Binary file added
BIN
+11.1 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/gridbased_pooling.cpython-37.pyc
Binary file not shown.
Binary file added
BIN
+2.16 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/loss.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added
BIN
+1.81 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/modules.cpython-310.pyc
Binary file not shown.
Binary file added
BIN
+1.79 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/modules.cpython-37.pyc
Binary file not shown.
Binary file added
BIN
+2.54 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/more_non_gridbased_pooling.cpython-310.pyc
Binary file not shown.
Binary file renamed
BIN
+2.55 KB
...more_non_gridbased_pooling.cpython-37.pyc → ...more_non_gridbased_pooling.cpython-37.pyc
Binary file not shown.
Binary file added
BIN
+17.6 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/non_gridbased_pooling.cpython-310.pyc
Binary file not shown.
Binary file renamed
BIN
+20.7 KB
...he__/non_gridbased_pooling.cpython-37.pyc → ...he__/non_gridbased_pooling.cpython-37.pyc
Binary file not shown.
Binary file added
BIN
+14.4 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/trainer.cpython-310.pyc
Binary file not shown.
Binary file added
BIN
+14.2 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/trainer.cpython-37.pyc
Binary file not shown.
Binary file added
BIN
+3.46 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/trajnet_evaluator.cpython-310.pyc
Binary file not shown.
Binary file added
BIN
+5.1 KB
jrdb_baselines/trajnetbaselines/lstm/__pycache__/utils.cpython-310.pyc
Binary file not shown.
Binary file renamed
BIN
+5.13 KB
...nes/lstm/__pycache__/utils.cpython-37.pyc → ...nes/lstm/__pycache__/utils.cpython-37.pyc
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import torch | ||
|
||
class L2Loss(torch.nn.Module):
    """L2 Loss (deterministic version of PredictionLoss).

    This loss penalizes only the primary trajectories: `forward` first slices
    out the primary pedestrian of each scene (via `batch_split`) before
    computing an element-wise MSE.
    """
    def __init__(self, keep_batch_dim=False):
        """
        Args:
            keep_batch_dim: if True, `forward` returns a per-scene loss vector
                (used by variety loss in SGAN) instead of a scalar mean.
        """
        super(L2Loss, self).__init__()
        # reduction='none' keeps per-element losses so forward can mix the
        # visible/invisible cases before reducing.
        self.loss = torch.nn.MSELoss(reduction='none')
        self.keep_batch_dim = keep_batch_dim
        # Constant scale applied to the final loss value.
        self.loss_multiplier = 100

    def col_loss(self, primary, neighbours, batch_split, gamma=2.0):
        """Collision penalty: penalizes the model when the primary pedestrian
        prediction comes close to the neighbour predictions.

        Args:
            primary: Tensor [pred_length, 1, 2]
            neighbours: Tensor [pred_length, num_neighbours, 2]
            batch_split: scene boundaries; consecutive pairs (start, end)
                delimit each scene's columns.
            gamma: sharpness of the exponential proximity penalty.

        NOTE(review): this mutates `neighbours` in place (NaN entries are
        overwritten with -1000 via the x != x NaN test) — callers' tensors are
        modified; confirm that is intended.
        """
        neighbours[neighbours != neighbours] = -1000
        exponential_loss = 0.0
        for (start, end) in zip(batch_split[:-1], batch_split[1:]):
            batch_primary = primary[:, start:start+1]
            batch_neigh = neighbours[:, start:end]
            # Euclidean distance primary->each neighbour per timestep.
            distance_to_neigh = torch.norm(batch_neigh - batch_primary, dim=2)
            # Despite the name, mask_far selects NEAR neighbours
            # (distance < 0.25; presumably metres — confirm units).
            mask_far = (distance_to_neigh < 0.25).detach()
            distance_to_neigh = -gamma * distance_to_neigh * mask_far
            # NOTE(review): masked-out entries have exponent 0 and therefore
            # each contribute exp(0) = 1 to the sum — a constant offset that
            # still scales with the number of entries; verify intended.
            exponential_loss += distance_to_neigh.exp().sum()
        # NOTE(review): if batch_split yields no (start, end) pairs,
        # exponential_loss stays a float and .sum() raises.
        return exponential_loss.sum()

    def forward(self, inputs, targets, batch_split):
        """Compute the (scaled) L2 loss over primary pedestrians only.

        Args:
            inputs: predicted trajectories, [pred_length, num_tracks, C].
            targets: ground-truth trajectories, same layout as `inputs`.
            batch_split: scene boundaries; batch_split[:-1] are the indices of
                the primary pedestrian of each scene.

        Returns:
            Scalar mean loss * loss_multiplier, or a per-scene vector if
            keep_batch_dim is set.
        """
        ## Extract primary pedestrians: transpose to [num_tracks, ...], index
        ## the primary rows, transpose back to [pred_length, num_primary, C].
        targets = targets.transpose(0, 1)
        targets = targets[batch_split[:-1]]
        targets = targets.transpose(0, 1)
        inputs = inputs.transpose(0, 1)
        inputs = inputs[batch_split[:-1]]
        inputs = inputs.transpose(0, 1)

        # Valid timesteps: channel 0 is non-NaN in BOTH prediction and GT.
        mask_gt = ~torch.isnan(targets[:,:,0])
        mask_pred = ~torch.isnan(inputs[:,:,0])
        mask = mask_pred*mask_gt

        # Element-wise MSE over fully-visible (primary, timestep) entries.
        loss_vis = self.loss(inputs[mask], targets[mask])
        if inputs[~mask].size(0) == 0:
            loss = loss_vis
        else:
            # For invisible entries, only the LAST channel is penalized; the
            # result is zero-padded back to C columns so it can be stacked
            # with loss_vis.
            # NOTE(review): the zeros(..., 2) padding only matches loss_vis's
            # width if C == 3 (x, y, plus a third channel — visibility?).
            # Confirm the dataset's per-step channel count.
            loss_invis = self.loss(inputs[~mask][:,-1], targets[~mask][:,-1])
            loss_invis = torch.cat((loss_invis.unsqueeze(1),torch.zeros(loss_invis.size(0),2).to(loss_invis.device)),dim=1)
            loss = torch.cat((loss_vis, loss_invis),dim=0)

        ## Used in variety loss (SGAN): reduce over entries, keep scene dim.
        if self.keep_batch_dim:
            return loss.mean(dim=0).mean(dim=1) * self.loss_multiplier

        return torch.mean(loss) * self.loss_multiplier
|
Oops, something went wrong.