Commit: * ddp * readme change * lint * code block * readme
Showing 9 changed files with 254 additions and 45 deletions.
@@ -0,0 +1,116 @@
"""5. DistributedDataParallel (DDP) Framework | ||
======================================================= | ||
Training deep neural networks on videos is very time consuming. | ||
For example, training a state-of-the-art SlowFast network on Kinetics400 dataset (with 240K 10-seconds short videos) | ||
using a server with 8 V100 GPUs takes more than 10 days. | ||
Slow training causes long research cycles and is not friendly for new comers and students to work on video related problems. | ||
Using distributed training is a natural choice. | ||
Spreading the huge computation over multiple machines can speed up training a lot. | ||
However, only a few open sourced Github repositories on video understanding support distributed training, | ||
and they often lack documentation for this feature. | ||
Besides, there is not much information/tutorial online on how to perform distributed training for deep video models. | ||
Hence, we provide a simple tutorial here to demonstrate how to use our DistributedDataParallel (DDP) framework to perform | ||
efficient distributed training. Note that, even in a single instance with multiple GPUs, | ||
DDP should be used and is much more efficient that vanilla dataparallel. | ||
""" | ||
|
||
########################################################################
# Distributed training
# --------------------
#
# There are two ways in which we can distribute the workload of training a neural network across multiple devices:
# data parallelism and model parallelism. Data parallelism refers to the case where each device stores a complete copy of the model.
# Each device works with a different part of the dataset, and the devices collectively update a shared model.
# When models are so large that they don't fit into device memory, model parallelism is useful:
# different devices are assigned the task of learning different parts of the model.
# In this tutorial, we describe how to train a model with devices distributed across machines in a data-parallel way.
# To be specific, we adopt DistributedDataParallel (DDP), which implements data parallelism at the module level and can be applied
# across multiple machines. DDP spawns multiple processes and creates a single GPU instance per process.
# It spreads the computation more evenly and is particularly useful for deep video model training.
#
# To keep this tutorial concise, we won't go into the details of how DDP works.
# Readers can refer to the PyTorch `Official Tutorials <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_ for more information.
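
################################################################
# For intuition, below is a minimal, generic PyTorch sketch of what DDP does
# under the hood. This is NOT our launcher: ``train_ddp_pytorch.py`` handles
# all of this for you based on the yaml configuration. ``MyVideoModel`` and
# ``train_dataset`` are hypothetical placeholders, and the environment
# variables are assumed to be set by a launcher such as ``torchrun``.
# ::
#
#     import os
#     import torch
#     import torch.distributed as dist
#     from torch.nn.parallel import DistributedDataParallel as DDP
#     from torch.utils.data import DataLoader, DistributedSampler
#
#     # one process per GPU; the launcher sets RANK, LOCAL_RANK and WORLD_SIZE
#     dist.init_process_group(backend='nccl')
#     local_rank = int(os.environ['LOCAL_RANK'])
#     torch.cuda.set_device(local_rank)
#
#     model = MyVideoModel().cuda(local_rank)        # hypothetical model
#     model = DDP(model, device_ids=[local_rank])    # gradients are all-reduced across processes
#
#     sampler = DistributedSampler(train_dataset)    # each process reads a different shard of the data
#     loader = DataLoader(train_dataset, batch_size=8, sampler=sampler)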
########################################################################
# How to use our DDP framework?
# ----------------------------------------------------------------------
#
# In order to perform distributed training, you need to (1) prepare the cluster; (2) prepare the environment; and (3) prepare your code
# and data.
################################################################
# We need a cluster in which the nodes can communicate with each other.
# The first step is to generate ssh keys for each machine.
# For better illustration, let's assume we have 2 machines, node1 and node2.
#
# First, ssh into node1 and type
# ::
#
#     ssh-keygen -t rsa
#
# Just follow the defaults, and you will have a file named ``id_rsa`` and a file named ``id_rsa.pub``, both under the ``~/.ssh/`` folder.
# ``id_rsa`` is the private RSA key, and ``id_rsa.pub`` is its public key.
################################################################
# Second, copy both files (``id_rsa`` and ``id_rsa.pub``) of node1 to all other machines.
# On each machine, you will also find an ``authorized_keys`` file under the ``~/.ssh/`` folder.
# Append the content of ``id_rsa.pub`` to ``authorized_keys``.
# This step makes sure that all the machines in the cluster are able to communicate with each other.
################################################################
# Before moving on to the next step, it is better to perform some sanity checks to make sure the communication is good.
# For example, if you can successfully ssh into the other machines, it means they can communicate with each other now. You are good to go.
# If there is any error during ssh, you can use the ``-vvv`` option to get verbose information for debugging.
################################################################
# Once you get the cluster ready, it is time to prepare the environment.
# Please check the `GluonCV installation guide <https://gluon-cv.mxnet.io/install/install-more.html>`_ for more information.
# Every machine should have the same environment, such as CUDA, PyTorch and GluonCV, so that the code is runnable.
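
################################################################
# As an optional sanity check (not required by the framework), you can run a
# short Python snippet on every node to confirm that the versions match and
# that the GPUs and the NCCL backend are visible.
# ::
#
#     import torch
#     import torch.distributed as dist
#
#     print(torch.__version__)             # should match on all nodes
#     print(torch.cuda.device_count())     # number of visible GPUs
#     print(dist.is_nccl_available())      # NCCL backend used for multi-GPU training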
################################################################
# Now it is time to prepare your code and data on each node.
# In terms of code changes, you only need to modify the `DDP_CONFIG` part in the yaml configuration file.
# For example, in our case we have 2 nodes, so change `WORLD_SIZE` to 2.
# `WOLRD_URLS` contains the IP addresses of all machines used for training (only LAN IPs are supported for now); put their IP addresses in the list.
# If `AUTO_RANK_MATCH` is True, the launcher will automatically assign a world rank number to each machine in `WOLRD_URLS`,
# and consider the first machine as the root. Please make sure to use the root's IP for `DIST_URL`.
# If `AUTO_RANK_MATCH` is False, you need to manually assign a rank number to each instance.
# The instance assigned `rank=0` will be considered the root machine.
# We suggest always enabling `AUTO_RANK_MATCH`.
# An example configuration looks like below,
# ::
#
#     DDP_CONFIG:
#       AUTO_RANK_MATCH: True
#       WORLD_SIZE: 2 # Total Number of machines
#       WORLD_RANK: 0 # Rank of this machine
#       DIST_URL: 'tcp://172.31.72.195:23456'
#       WOLRD_URLS: ['172.31.72.195', '172.31.72.196']
#
#       GPU_WORLD_SIZE: 8 # Total Number of GPUs, will be assigned automatically
#       GPU_WORLD_RANK: 0 # Rank of GPUs, will be assigned automatically
#       DIST_BACKEND: 'nccl'
#       GPU: 0 # Rank of GPUs in the machine, will be assigned automatically
#       DISTRIBUTED: True
################################################################
# Once this is done, you can kick off the training on each machine.
# Simply run `train_ddp_pytorch.py`/`test_ddp_pytorch.py` with the desired configuration file on each instance, e.g.,
# ::
#
#     python train_ddp_pytorch.py --config-file XXX.yaml
#
# If you are using multiple instances for training, we suggest starting the run on the root instance first
# and then starting the code on the other instances.
# By default, the log will only be shown on the root instance.
################################################################
# In the end, we want to point out that we have integrated the dataloader and the training/testing loop into our DDP framework.
# If you simply want to try our model zoo on your dataset/use case, please see the previous tutorial on how to finetune.
# If you have a new video model, you can add it to the model zoo (e.g., as a single .py file) and enjoy the speedup brought by our DDP framework.
# You don't need to handle the multiprocess dataloading and the underlying distributed training setup.
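
################################################################
# For illustration only, such a new video model is usually just a regular
# ``torch.nn.Module`` kept in a single ``.py`` file. The sketch below is a
# hypothetical skeleton; the exact file location and the way the model is
# registered in the model zoo should follow the existing models in the codebase.
# ::
#
#     import torch
#     import torch.nn as nn
#
#     class MySimpleVideoNet(nn.Module):
#         """Hypothetical example: a tiny 3D-conv video classifier."""
#         def __init__(self, num_classes=400):
#             super().__init__()
#             self.features = nn.Sequential(
#                 nn.Conv3d(3, 64, kernel_size=3, padding=1),
#                 nn.ReLU(inplace=True),
#                 nn.AdaptiveAvgPool3d(1))
#             self.fc = nn.Linear(64, num_classes)
#
#         def forward(self, x):
#             # x: (batch, channels, time, height, width)
#             x = self.features(x).flatten(1)
#             return self.fc(x)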
@@ -0,0 +1,71 @@
# Action Recognition[1]
[GluonCV Model Zoo](https://gluon-cv.mxnet.io/model_zoo/action_recognition.html)

## Inference/Calibration Tutorial

### FP32 inference

```
export CPUs=`lscpu | grep 'Core(s) per socket' | awk '{print $4}'`
export OMP_NUM_THREADS=${CPUs}
export KMP_AFFINITY=granularity=fine,noduplicates,compact,1,0
# dummy data
python test_recognizer.py --model inceptionv3_ucf101 --use-pretrained --mode hybrid --input-size 299 --new-height 340 --new-width 450 --num-segments 3 --batch-size 64 --benchmark
# real data
python test_recognizer.py --model inceptionv3_ucf101 --use-pretrained --mode hybrid --input-size 299 --new-height 340 --new-width 450 --num-segments 3 --batch-size 64
```

### Calibration

In naive mode, FP32 models are calibrated using 5 mini-batches of data (32 images per batch). Quantized models will be saved into `./model/`.

```
# ucf101 dataset
python test_recognizer.py --model inceptionv3_ucf101 --new-height 340 --new-width 450 --input-size 299 --num-segments 3 --use-pretrained --calibration
# kinetics400 dataset
python test_recognizer.py --dataset kinetics400 --data-dir path/to/datasets --model resnet18_v1b_kinetics400 --use-pretrained --num-classes 400 --new-height 256 --new-width 340 --input-size 224 --num-segments 7 --calibration
```

### INT8 Inference

```
export CPUs=`lscpu | grep 'Core(s) per socket' | awk '{print $4}'`
export OMP_NUM_THREADS=${CPUs}
export KMP_AFFINITY=granularity=fine,noduplicates,compact,1,0
# dummy data
python test_recognizer.py --model inceptionv3_ucf101 --mode hybrid --input-size 299 --new-height 340 --new-width 450 --batch-size 64 --num-segments 3 --quantized --benchmark
# real data
python test_recognizer.py --model inceptionv3_ucf101 --mode hybrid --input-size 299 --new-height 340 --new-width 450 --batch-size 64 --num-segments 3 --quantized
# deploy static model
python test_recognizer.py --model inceptionv3_ucf101 --deploy --model-prefix ./model/inceptionv3_ucf101-quantized-naive --input-size 299 --new-height 340 --new-width 450 --batch-size 64 --num-segments 3 --benchmark
```

We also recommend binding processes to specific cores via `numactl` for better performance, as below:

```
numactl --physcpubind=0-27 --membind=0 python test_recognizer.py ...
```

## Performance
The results below are collected on an Intel(R) VNNI enabled C5.12xlarge instance with 24 physical cores.

| model | fp32 Top-1 | int8 Top-1 |
| -- | -- | -- |
| inceptionv3_ucf101 | 86.92 | 86.55 |
| vgg16_ucf101 | 81.86 | 81.41 |
| resnet18_v1b_kinetics400 | 63.29 | 63.14 |
| resnet50_v1b_kinetics400 | 68.08 | 68.15 |
| inceptionv3_kinetics400 | 67.93 | 67.92 |

## References

1. Limin Wang, Yuanjun Xiong, Zhe Wang and Yu Qiao. “Towards Good Practices for Very Deep Two-Stream ConvNets.” arXiv preprint arXiv:1507.02159, 2015.
@@ -1,71 +1,72 @@
# Action Recognition[1]
[GluonCV Model Zoo](https://gluon-cv.mxnet.io/model_zoo/action_recognition.html)

## PyTorch Tutorial

### [How to train?](https://cv.gluon.ai/build/examples_torch_action_recognition/finetune_custom.html)

```
python train_ddp_pytorch.py --config-file ./configuration/XXX.yaml
```

If multi-grid training is needed,
```
python train_ddp_shortonly_pytorch.py --config-file ./configuration/XXX.yaml
```
Note that we only use the short cycle here because it is stable and applies to a range of models.

### [How to evaluate?](https://cv.gluon.ai/build/examples_torch_action_recognition/demo_i3d_kinetics400.html)

```
# Change PRETRAINED to True if using our pretrained model zoo
python test_ddp_pytorch.py --config-file ./configuration/XXX.yaml
```

### [How to extract features?](https://cv.gluon.ai/build/examples_torch_action_recognition/extract_feat.html)

```
python feat_extract_pytorch.py --config-file ./configuration/XXX.yaml
```

### [How to get speed measurement?](https://cv.gluon.ai/build/examples_torch_action_recognition/speed.html)

```
python get_flops.py --config-file ./configuration/XXX.yaml
python get_fps.py --config-file ./configuration/XXX.yaml
```

## MXNet Tutorial

### [How to train?](https://cv.gluon.ai/build/examples_action_recognition/dive_deep_i3d_kinetics400.html)
The MXNet codebase adopts argparse, hence requiring many arguments. Please check the [model zoo page](https://cv.gluon.ai/model_zoo/action_recognition.html) for detailed training commands.

```
python train_recognizer.py
```

### [How to evaluate?](https://cv.gluon.ai/build/examples_action_recognition/demo_i3d_kinetics400.html)

```
python test_recognizer.py
```

### [How to extract features?](https://cv.gluon.ai/build/examples_action_recognition/feat_custom.html)

```
python feat_extract.py
```

### [How to do inference on your own video?](https://cv.gluon.ai/build/examples_action_recognition/demo_custom.html)

```
python inference.py
```

## MXNet calibration
Please check out [CALIBRATION.md](https://raw.githubusercontent.com/dmlc/gluon-cv/master/scripts/action-recognition/CALIBRATION.md) for more information on INT8 model calibration and inference.

## Reproducing our arXiv survey paper
Please check out [ARXIV.md](https://raw.githubusercontent.com/dmlc/gluon-cv/master/scripts/action-recognition/ARXIV.md) for more information on how to get the same dataset and how to reproduce all the methods in our model zoo.