Temporal Shift Module

TSM-R50 from "TSM: Temporal Shift Module for Efficient Video Understanding" https://arxiv.org/abs/1811.08383

TSM is a widely used action recognition model. This TensorRT implementation has been tested with TensorRT 5.1 and TensorRT 7.2.

For the PyTorch implementation, you can refer to open-mmlab/mmaction2 or mit-han-lab/temporal-shift-module.

For more details about the shift module (the core of TSM), refer to test_shift.py.
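As a rough illustration of what the shift module does, the following is a minimal NumPy sketch of the zero-padded bidirectional temporal shift described in the TSM paper. It is not the code in test_shift.py; the function name `temporal_shift` and the NumPy-based layout are assumptions made for this example, while `num_segments` and `shift_div` correspond to the settings of the same names used in this README.

```python
import numpy as np


def temporal_shift(x, num_segments=8, shift_div=8):
    """Shift a fraction of the channels along the time axis.

    x has shape (N * num_segments, C, H, W). The first C // shift_div
    channels take their values from the next frame, the following
    C // shift_div channels take their values from the previous frame,
    and the remaining channels are left unchanged. Positions with no
    source frame are filled with zeros.
    """
    nt, c, h, w = x.shape
    n = nt // num_segments
    x = x.reshape(n, num_segments, c, h, w)
    fold = c // shift_div

    out = np.zeros_like(x)
    out[:, :-1, :fold] = x[:, 1:, :fold]                  # values from the next frame
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # values from the previous frame
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]             # untouched channels
    return out.reshape(nt, c, h, w)


if __name__ == "__main__":
    # One clip of 8 segments with 64 channels (hypothetical sizes).
    clip = np.random.rand(8, 64, 56, 56).astype(np.float32)
    print(temporal_shift(clip).shape)
```

With `shift_div = 8`, one eighth of the channels move in each direction and three quarters stay in place, so the operation mixes information across neighboring frames without adding any parameters.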

Tutorial

  • For an end-to-end example, refer to demo.sh.

    • Requirements: torch>=1.3.0 and torchvision installed.
  • Step 1: Train or download TSM-R50 checkpoints from the official GitHub repo or MMAction2.

    • Supported settings: num_segments, shift_div, num_classes.
    • Fixed settings: backbone(ResNet50), shift_place(blockres), temporal_pool(False).
  • Step 2: Convert PyTorch checkpoints to TensorRT weights.

python gen_wts.py /path/to/pytorch.pth --out-filename /path/to/tensorrt.wts
  • Step 3: Test Python API.
    • Modify configs in tsm_r50.py.
    • Inference with tsm_r50.py.
# Supported settings
BATCH_SIZE = 1
NUM_SEGMENTS = 8
INPUT_H = 224
INPUT_W = 224
OUTPUT_SIZE = 400
SHIFT_DIV = 8
usage: tsm_r50.py [-h] [--tensorrt-weights TENSORRT_WEIGHTS] [--input-video INPUT_VIDEO] [--save-engine-path SAVE_ENGINE_PATH] [--load-engine-path LOAD_ENGINE_PATH] [--test-mmaction2] [--mmaction2-config MMACTION2_CONFIG] [--mmaction2-checkpoint MMACTION2_CHECKPOINT] [--test-cpp] [--cpp-result-path CPP_RESULT_PATH]

optional arguments:
  -h, --help            show this help message and exit
  --tensorrt-weights TENSORRT_WEIGHTS
                        Path to TensorRT weights, which is generated by gen_weights.py
  --input-video INPUT_VIDEO
                        Path to local video file
  --save-engine-path SAVE_ENGINE_PATH
                        Save engine to local file
  --load-engine-path LOAD_ENGINE_PATH
                        Saved engine file path
  --test-mmaction2      Compare TensorRT results with MMAction2 Results
  --mmaction2-config MMACTION2_CONFIG
                        Path to MMAction2 config file
  --mmaction2-checkpoint MMACTION2_CHECKPOINT
                        Path to MMAction2 checkpoint url or file path
  --test-cpp            Compare Python API results with C++ API results
  --cpp-result-path CPP_RESULT_PATH
                        Path to C++ API results
  • Step 4: Test C++ API.
    • Modify configs in tsm_r50.cpp.
    • Build from source code: mkdir build && cd build && cmake .. && make
    • Generate Engine file: ./tsm_r50 -s
    • Run inference with the generated engine file and write predictions to a local file: ./tsm_r50 -d
    • Compare results with Python API: python tsm_r50.py --tensorrt-weights /path/to/tensorrt.wts --test-cpp --cpp-result-path /path/to/cpp-result.txt
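
To make Step 2 more concrete, here is a minimal sketch of a plain-text weights file of the kind often used when building TensorRT networks layer by layer. This README does not specify the .wts layout, so the format below is an assumption: a first line with the tensor count, then one line per tensor containing its name, its element count, and each float32 value as big-endian hex. The helpers `write_wts` and `read_wts` are hypothetical names for this example; gen_wts.py remains the authoritative converter.

```python
import struct


def write_wts(tensors, path):
    """Write a dict of name -> flat sequence of floats in the assumed .wts layout."""
    with open(path, "w") as f:
        f.write(f"{len(tensors)}\n")
        for name, values in tensors.items():
            values = [float(v) for v in values]
            f.write(f"{name} {len(values)}")
            for v in values:
                # Each value is stored as the hex of its big-endian float32 bytes.
                f.write(" " + struct.pack(">f", v).hex())
            f.write("\n")


def read_wts(path):
    """Read the assumed .wts layout back into a dict of name -> list of floats."""
    tensors = {}
    with open(path) as f:
        count = int(f.readline())
        for _ in range(count):
            parts = f.readline().split()
            name, size = parts[0], int(parts[1])
            tensors[name] = [
                struct.unpack(">f", bytes.fromhex(h))[0]
                for h in parts[2:2 + size]
            ]
    return tensors
```

In a real conversion the values would come from a flattened PyTorch `state_dict`; the sketch uses plain Python sequences so that it runs without torch.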

TODO

  • Python Shift module.
  • Generate wts of official tsm and mmaction2 tsm.
  • Python API Definition
  • Test with mmaction2 demo
  • Tutorial
  • C++ API Definition