This repo re-implements Social LSTM and implements a new Spatial Pyramid Social LSTM that we propose.
This implementation is based on https://github.com/vvanirudh/social-lstm-tf. We fix several bugs and critical errors in that repo, and we upgrade the code to the TensorFlow 1.8.0 API.
- getPixelCoordinates.m: the MATLAB code that transforms the original ETH dataset into pixel_pos.csv, which is used by our code. This file is based on the referred implementation.
- pixel_pos.csv: the data file used by our code.
- transformed_data.pkl: our code transforms pixel_pos.csv and saves the result as transformed_data.pkl.
- DataLoader.py: handles data loading and preprocessing.
- grid.py: calculates the grid or pyramid mask; called by train.py.
- model.py: IMPORTANT! All models (including Social LSTM and Spatial Pyramid Social LSTM) are defined here.
- social_sample.py: prediction/test code; it accepts console parameters (run social_sample.py --help to list them).
- social_visualize.py: draws the prediction plots.
- train.py: training code; it accepts console parameters (run train.py --help to list them).
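The grid mask that grid.py computes, and the social pooling it feeds, can be sketched as follows. This is a minimal NumPy illustration, not the code in grid.py or model.py: the function names (grid_mask, pyramid_mask, social_tensor), the (x, y) position layout, and the pyramid levels (1, 2, 4) are assumptions. In particular, we assume here that a pyramid mask concatenates grid masks at several resolutions, in the spirit of spatial pyramid pooling; check grid.py and our report for the definition actually used.

```python
import numpy as np


def grid_mask(positions, neighborhood_size, grid_size):
    """Return an (N, N, grid_size**2) 0/1 mask for one frame.

    positions: (N, 2) array of (x, y) coordinates, one row per pedestrian.
    mask[i, j, c] == 1 iff pedestrian j (j != i) lies in cell c of the
    grid_size x grid_size grid centred on pedestrian i. The grid covers a
    square of side neighborhood_size, in the same units as positions.
    """
    positions = np.asarray(positions, dtype=float)
    n = positions.shape[0]
    mask = np.zeros((n, n, grid_size * grid_size))
    half = neighborhood_size / 2.0
    for i in range(n):
        low = positions[i] - half  # lower-left corner of i's neighbourhood
        for j in range(n):
            if i == j:
                continue
            offset = positions[j] - low
            # Ignore pedestrians outside [low, low + neighborhood_size).
            if np.any(offset < 0) or np.any(offset >= neighborhood_size):
                continue
            # Map the offset to integer cell indices; clip guards rounding at the edge.
            cells = np.minimum((offset / neighborhood_size * grid_size).astype(int),
                               grid_size - 1)
            cell_x, cell_y = cells
            mask[i, j, cell_x + cell_y * grid_size] = 1
    return mask


def pyramid_mask(positions, neighborhood_size, levels=(1, 2, 4)):
    """Hypothetical pyramid mask: grid masks at several resolutions,
    concatenated along the cell axis -> (N, N, sum(g * g for g in levels))."""
    return np.concatenate(
        [grid_mask(positions, neighborhood_size, g) for g in levels], axis=2)


def social_tensor(mask, hidden):
    """Social pooling: for each pedestrian i and cell c, sum the hidden
    states of the neighbours j that fall in that cell.

    mask: (N, N, C), hidden: (N, D) -> returns (N, C, D).
    """
    return np.einsum('ijc,jd->icd', mask, hidden)
```

For example, with positions [[0, 0], [1, 1], [10, 10]], neighborhood_size=4 and grid_size=2, pedestrian 1 falls in cell 3 of pedestrian 0's grid, pedestrian 0 falls in cell 0 of pedestrian 1's grid, and pedestrian 2 is outside both neighbourhoods, so the mask has exactly two non-zero entries.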
This directory contains several prediction plots for the "Spatial Pyramid Social LSTM" method. Plots for the other methods are unavailable because the lab server is under maintenance (explained in our report).
This directory contains the model file for the "Spatial Pyramid Social LSTM" method. Models for the other methods are unavailable because the lab server is under maintenance (explained in our report).
- Delete everything under social_lstm/save/.
- Run train.py to train a model.
- Run social_sample.py to predict.
- Run social_visualize.py to visualize the predictions.
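The first step, deleting everything under social_lstm/save/, can be scripted. This is a minimal sketch, assuming it is run from the directory that contains social_lstm/; clear_directory is a hypothetical helper, not part of this repo. It removes the contents of the directory but keeps the directory itself, so training can write new checkpoints there.

```python
import shutil
from pathlib import Path


def clear_directory(path):
    """Hypothetical helper: delete everything under `path`, keep `path` itself."""
    root = Path(path)
    root.mkdir(parents=True, exist_ok=True)  # create it if it does not exist yet
    for child in root.iterdir():
        # Remove real subdirectories recursively; unlink files and symlinks.
        if child.is_dir() and not child.is_symlink():
            shutil.rmtree(child)
        else:
            child.unlink()
```

For example, call clear_directory("social_lstm/save") before running train.py.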