From 18913750053cd40b6277e8e821ff35f5fec34cce Mon Sep 17 00:00:00 2001 From: Yi Zhu Date: Tue, 1 Dec 2020 23:02:59 -0800 Subject: [PATCH] DDP tutorial (#1553) * ddp * readme change * lint * code block * readme --- README.md | 9 +- docs/build.yml | 9 +- docs/conf.py | 2 +- .../action_recognition/ddp_pytorch.py | 116 ++++++++++++++++++ .../demo_i3d_kinetics400.py | 3 +- docs/tutorials_torch/index.rst | 6 + scripts/action-recognition/CALIBRATION.md | 71 +++++++++++ scripts/action-recognition/README.md | 81 ++++++------ .../i3d_resnet50_v1_kinetics400.yaml | 2 +- 9 files changed, 254 insertions(+), 45 deletions(-) create mode 100644 docs/tutorials_torch/action_recognition/ddp_pytorch.py create mode 100644 scripts/action-recognition/CALIBRATION.md diff --git a/README.md b/README.md index 3c03c8a2e4..7f0caed0ce 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ The following commands install the stable version of GluonCV and MXNet: ```bash pip install gluoncv --upgrade +# native pip install -U --pre mxnet -f https://dist.mxnet.io/python/mkl # cuda 10.2 pip install -U --pre mxnet -f https://dist.mxnet.io/python/cu102mkl @@ -80,6 +81,7 @@ You may get access to latest features and bug fixes with the following commands ```bash pip install gluoncv --pre --upgrade +# native pip install -U --pre mxnet -f https://dist.mxnet.io/python/mkl # cuda 10.2 pip install -U --pre mxnet -f https://dist.mxnet.io/python/cu102mkl @@ -97,6 +99,8 @@ The following commands install the stable version of GluonCV and PyTorch: ```bash pip install gluoncv --upgrade +# native +pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html # cuda 10.2 pip install torch==1.6.0 torchvision==0.7.0 ``` @@ -111,7 +115,10 @@ You may get access to latest features and bug fixes with the following commands ```bash pip install gluoncv --pre --upgrade -pip install torch==1.6.0 torchvision==0.7.0 # for cuda 10.2 +# native +pip install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html +# cuda 10.2 +pip install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html ``` diff --git a/docs/build.yml b/docs/build.yml index bd2b3a2352..dbda4f49c8 100644 --- a/docs/build.yml +++ b/docs/build.yml @@ -1,4 +1,8 @@ name: gluon_vision_docs +channels: + - pytorch + - conda-forge + - defaults dependencies: - python=3.7 - sphinx>=1.5.5 @@ -6,7 +10,9 @@ dependencies: - numpy - matplotlib - sphinx_rtd_theme -- pip=19.1.1 +- pip=20.2 +- pytorch=1.6.0 +- torchvision=0.7.0 - pip: - https://github.com/mli/mx-theme/tarball/0.3.1 - sphinx-gallery @@ -24,3 +30,4 @@ dependencies: - cython - pycocotools - autocfg + - yacs diff --git a/docs/conf.py b/docs/conf.py index 55064e9acf..63f54d5477 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -89,7 +89,7 @@ 'build/examples_deployment', 'build/examples_torch_action_recognition'], - 'filename_pattern': '.pydisabled', + 'filename_pattern': '.py', 'ignore_pattern': 'im2rec.py', 'expected_failing_examples': [], diff --git a/docs/tutorials_torch/action_recognition/ddp_pytorch.py b/docs/tutorials_torch/action_recognition/ddp_pytorch.py new file mode 100644 index 0000000000..14f4650f48 --- /dev/null +++ b/docs/tutorials_torch/action_recognition/ddp_pytorch.py @@ -0,0 +1,116 @@ +"""5. DistributedDataParallel (DDP) Framework +======================================================= + +Training deep neural networks on videos is very time consuming. 
+For example, training a state-of-the-art SlowFast network on the Kinetics400 dataset (240K short videos of roughly 10 seconds each)
+using a server with 8 V100 GPUs takes more than 10 days.
+Slow training leads to long research cycles and raises the barrier for newcomers and students working on video-related problems.
+Using distributed training is a natural choice.
+Spreading the huge computation over multiple machines can speed up training significantly.
+However, only a few open-source GitHub repositories on video understanding support distributed training,
+and they often lack documentation for this feature.
+Besides, there is not much information or tutorial material online on how to perform distributed training for deep video models.
+
+Hence, we provide a simple tutorial here to demonstrate how to use our DistributedDataParallel (DDP) framework to perform
+efficient distributed training. Note that, even on a single instance with multiple GPUs,
+DDP should be used, as it is much more efficient than vanilla ``DataParallel``.
+
+
+"""
+
+########################################################################
+# Distributed training
+# --------------------
+#
+# There are two ways in which we can distribute the workload of training a neural network across multiple devices:
+# data parallelism and model parallelism. Data parallelism refers to the case where each device stores a complete copy of the model.
+# Each device works with a different part of the dataset, and the devices collectively update a shared model.
+# When a model is so large that it does not fit into device memory, model parallelism is useful:
+# here, different devices are assigned the task of learning different parts of the model.
+# In this tutorial, we describe how to train a model with devices distributed across machines in a data-parallel way.
+# To be specific, we adopt DistributedDataParallel (DDP), which implements data parallelism at the module level and can be applied
+# across multiple machines. DDP spawns multiple processes and creates a single GPU instance per process.
+# It spreads the computation more evenly and is particularly useful for training deep video models.
+#
+# To keep this tutorial concise, we will not go into the details of how DDP works.
+# Readers can refer to the PyTorch `Official Tutorials `_ for more information.
+
+
+########################################################################
+# How to use our DDP framework?
+# ----------------------------------------------------------------------
+#
+# In order to perform distributed training, you need to (1) prepare the cluster; (2) prepare the environment; and (3) prepare your code
+# and data.
+
+################################################################
+# We need a cluster in which the nodes can communicate with each other.
+# The first step is to generate ssh keys for each machine.
+# For better illustration, let's assume we have 2 machines, node1 and node2.
+#
+# First, ssh into node1 and type
+# ::
+#
+#     ssh-keygen -t rsa
+#
+# Just follow the defaults, and you will end up with a file named ``id_rsa`` and a file named ``id_rsa.pub``, both under the ``~/.ssh/`` folder.
+# ``id_rsa`` is the private RSA key, and ``id_rsa.pub`` is its public key.
+
+################################################################
+# Second, copy both files (``id_rsa`` and ``id_rsa.pub``) from node1 to all other machines,
+# and append the content of ``id_rsa.pub`` to the ``authorized_keys`` file that you will also find under the ``~/.ssh/`` folder of each machine.
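+# For example, from node1 to node2 this could look like the following (a sketch only: ``ubuntu@node2`` is a placeholder
+# user and host, and it assumes node2 still accepts password login at this point),
+# ::
+#
+#     # copy the key pair to node2, then append the public key to node2's authorized_keys
+#     scp ~/.ssh/id_rsa ~/.ssh/id_rsa.pub ubuntu@node2:~/.ssh/
+#     cat ~/.ssh/id_rsa.pub | ssh ubuntu@node2 'cat >> ~/.ssh/authorized_keys'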
+# This step makes sure that all the machines in the cluster are able to communicate with each other.
+
+################################################################
+# Before moving on to the next step, it is a good idea to perform some sanity checks to make sure the communication works.
+# For example, if you can successfully ssh into the other machines, they can communicate with each other and you are good to go.
+# If there is any error during ssh, you can use the ``-vvv`` option to get verbose information for debugging.
+
+################################################################
+# Once you get the cluster ready, it is time to prepare the environment.
+# Please check the `GluonCV installation guide `_ for more information.
+# Every machine should have the same environment (e.g., CUDA, PyTorch and GluonCV versions) so that the code is runnable.
+
+################################################################
+# Now it is time to prepare your code and data on each node.
+# In terms of code changes, you only need to modify the `DDP_CONFIG` part in the yaml configuration file.
+# For example, in our case we have 2 nodes, so change `WORLD_SIZE` to 2.
+# `WOLRD_URLS` contains the IP addresses of all machines used for training (only LAN IPs are supported for now); put their IP addresses in the list.
+# If `AUTO_RANK_MATCH` is True, the launcher will automatically assign a world rank number to each machine in `WOLRD_URLS`,
+# and consider the first machine as the root. Please make sure to use the root machine's IP for `DIST_URL`.
+# If `AUTO_RANK_MATCH` is False, you need to manually assign a rank number to each instance.
+# The instance assigned `rank=0` will be considered the root machine.
+# We suggest always enabling `AUTO_RANK_MATCH`.
+# An example configuration looks like the following.
+# ::
+#
+#     DDP_CONFIG:
+#       AUTO_RANK_MATCH: True
+#       WORLD_SIZE: 2 # Total Number of machines
+#       WORLD_RANK: 0 # Rank of this machine
+#       DIST_URL: 'tcp://172.31.72.195:23456'
+#       WOLRD_URLS: ['172.31.72.195', '172.31.72.196']
+#
+#       GPU_WORLD_SIZE: 8 # Total Number of GPUs, will be assigned automatically
+#       GPU_WORLD_RANK: 0 # Rank of GPUs, will be assigned automatically
+#       DIST_BACKEND: 'nccl'
+#       GPU: 0 # Rank of GPUs in the machine, will be assigned automatically
+#       DISTRIBUTED: True
+
+################################################################
+# Once this is done, you can kick off the training on each machine.
+# Simply run `train_ddp_pytorch.py`/`test_ddp_pytorch.py` with the desired configuration file on each instance, e.g.,
+# ::
+#
+#     python train_ddp_pytorch.py --config-file XXX.yaml
+#
+# If you are using multiple instances for training, we suggest starting the run on the root instance first
+# and then starting the code on the other instances.
+# The log will only be shown on the root instance by default.
+
+################################################################
+# In the end, we want to point out that the dataloader and the training/testing loop are already integrated into our DDP framework.
+# If you simply want to try out our model zoo on your dataset/use case, please see the previous tutorial on how to finetune.
+# If you have a new video model, you can add it to the model zoo (e.g., as a single .py file) and enjoy the speedup brought by our DDP framework.
+# You don't need to handle the multi-process data loading or the underlying distributed training setup yourself.
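+
+################################################################
+# For the curious, the underlying setup that our DDP framework handles for you boils down to a handful of core PyTorch calls.
+# Below is a minimal, generic single-machine sketch of raw PyTorch DDP (not the actual code of our launcher);
+# the tiny linear model, the random tensors and the port number are placeholders used purely for illustration.
+# ::
+#
+#     import torch
+#     import torch.distributed as dist
+#     import torch.multiprocessing as mp
+#     from torch.nn.parallel import DistributedDataParallel as DDP
+#
+#     def run_worker(gpu, world_size):
+#         # one process per GPU; on a single machine the GPU id can double as the rank
+#         dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456',
+#                                 world_size=world_size, rank=gpu)
+#         torch.cuda.set_device(gpu)
+#
+#         model = torch.nn.Linear(512, 400).cuda(gpu)   # placeholder for a video model
+#         model = DDP(model, device_ids=[gpu])          # gradients are synchronized automatically
+#
+#         data = torch.utils.data.TensorDataset(torch.randn(64, 512), torch.randint(0, 400, (64,)))
+#         sampler = torch.utils.data.distributed.DistributedSampler(data)  # each rank reads its own shard
+#         loader = torch.utils.data.DataLoader(data, batch_size=8, sampler=sampler)
+#
+#         criterion = torch.nn.CrossEntropyLoss().cuda(gpu)
+#         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+#         for clip, label in loader:
+#             optimizer.zero_grad()
+#             loss = criterion(model(clip.cuda(gpu)), label.cuda(gpu))
+#             loss.backward()
+#             optimizer.step()
+#
+#         dist.destroy_process_group()
+#
+#     if __name__ == '__main__':
+#         n_gpus = torch.cuda.device_count()
+#         mp.spawn(run_worker, args=(n_gpus,), nprocs=n_gpus)
+#
+# Our framework takes care of these steps for you (plus multi-machine rank assignment, logging and the video dataloaders),
+# which is why you only need to fill in ``DDP_CONFIG`` and provide the model.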
diff --git a/docs/tutorials_torch/action_recognition/demo_i3d_kinetics400.py b/docs/tutorials_torch/action_recognition/demo_i3d_kinetics400.py index 9c02246729..7f7bf19abe 100644 --- a/docs/tutorials_torch/action_recognition/demo_i3d_kinetics400.py +++ b/docs/tutorials_torch/action_recognition/demo_i3d_kinetics400.py @@ -71,10 +71,11 @@ # Next, we load a pre-trained I3D model. Make sure to change the ``pretrained`` in the configuration file to True. -config_file = './scripts/action-recognition/configuration/i3d_resnet50_v1_kinetics400.yaml' +config_file = '../../../scripts/action-recognition/configuration/i3d_resnet50_v1_kinetics400.yaml' cfg = get_cfg_defaults() cfg.merge_from_file(config_file) model = get_model(cfg) +model.eval() print('%s model is successfully loaded.' % cfg.CONFIG.MODEL.NAME) diff --git a/docs/tutorials_torch/index.rst b/docs/tutorials_torch/index.rst index 6bdd24fd94..cf1c3cf3f7 100644 --- a/docs/tutorials_torch/index.rst +++ b/docs/tutorials_torch/index.rst @@ -33,6 +33,12 @@ Action Recognition How to compute FLOPS, number of parameters, latency and fps of a video model + .. card:: + :title: DistributedDataParallel (DDP) framework + :link: ../build/examples_torch_action_recognition/ddp_pytorch.html + + How to use our DistributedDataParallel framework + .. toctree:: :hidden: diff --git a/scripts/action-recognition/CALIBRATION.md b/scripts/action-recognition/CALIBRATION.md new file mode 100644 index 0000000000..0db57e1b25 --- /dev/null +++ b/scripts/action-recognition/CALIBRATION.md @@ -0,0 +1,71 @@ +# Action Recognition[1] +[GluonCV Model Zoo](https://gluon-cv.mxnet.io/model_zoo/action_recognition.html) + +## Inference/Calibration Tutorial + +### FP32 inference + +``` + +export CPUs=`lscpu | grep 'Core(s) per socket' | awk '{print $4}'` +export OMP_NUM_THREADS=${CPUs} +export KMP_AFFINITY=granularity=fine,noduplicates,compact,1,0 + +# dummy data +python test_recognizer.py --model inceptionv3_ucf101 --use-pretrained --mode hybrid --input-size 299 --new-height 340 --new-width 450 --num-segments 3 --batch-size 64 --benchmark + +# real data +python test_recognizer.py --model inceptionv3_ucf101 --use-pretrained --mode hybrid --input-size 299 --new-height 340 --new-width 450 --num-segments 3 --batch-size 64 +``` + +### Calibration + +In naive mode, FP32 models are calibrated by using 5 mini-batches of data (32 images per batch). Quantized models will be saved into `./model/`. 
+ +``` +# ucf101 dataset +python test_recognizer.py --model inceptionv3_ucf101 --new-height 340 --new-width 450 --input-size 299 --num-segments 3 --use-pretrained --calibration + +# kinetics400 dataset +python test_recognizer.py --dataset kinetics400 --data-dir path/to/datasets --model resnet18_v1b_kinetics400 --use-pretrained --num-classes 400 --new-height 256 --new-width 340 --input-size 224 --num-segments 7 --calibration +``` + +### INT8 Inference + +``` + +export CPUs=`lscpu | grep 'Core(s) per socket' | awk '{print $4}'` +export OMP_NUM_THREADS=${CPUs} +export KMP_AFFINITY=granularity=fine,noduplicates,compact,1,0 + +# dummy data +python test_recognizer.py --model inceptionv3_ucf101 --mode hybrid --input-size 299 --new-height 340 --new-width 450 --batch-size 64 --num-segments 3 --quantized --benchmark + +# real data +python test_recognizer.py --model inceptionv3_ucf101 --mode hybrid --input-size 299 --new-height 340 --new-width 450 --batch-size 64 --num-segments 3 --quantized + +# deploy static model +python test_recognizer.py --model inceptionv3_ucf101 --deploy --model-prefix ./model/inceptionv3_ucf101-quantized-naive --input-size 299 --new-height 340 --new-width 450 --batch-size 64 --num-segments 3 --benchmark + +``` + +Users are also recommended to bind processes to specific cores via `numactl` for better performance, like below: + +``` +numactl --physcpubind=0-27 --membind=0 python test_recognizer.py ... +``` + +## Performance +Below results are collected based on Intel(R) VNNI enabled C5.12xlarge with 24 physical cores. + +|model | fp32 Top-1 | int8 Top-1 | +|-- | -- | -- | +inceptionv3_ucf101 |86.92 | 86.55 | +vgg16_ucf101 |81.86 | 81.41 | +resnet18_v1b_kinetics400 |63.29 | 63.14 | +resnet50_v1b_kinetics400 |68.08 | 68.15 | +inceptionv3_kinetics400 |67.93 | 67.92 | + +## References + +1. Limin Wang, Yuanjun Xiong, Zhe Wang and Yu Qiao. “Towards Good Practices for Very Deep Two-Stream ConvNets.” arXiv preprint arXiv:1507.02159, 2015. diff --git a/scripts/action-recognition/README.md b/scripts/action-recognition/README.md index 0db57e1b25..c46d94c943 100644 --- a/scripts/action-recognition/README.md +++ b/scripts/action-recognition/README.md @@ -1,71 +1,72 @@ # Action Recognition[1] [GluonCV Model Zoo](https://gluon-cv.mxnet.io/model_zoo/action_recognition.html) -## Inference/Calibration Tutorial +## PyTorch Tutorial -### FP32 inference +### [How to train?](https://cv.gluon.ai/build/examples_torch_action_recognition/finetune_custom.html) +``` +python train_ddp_pytorch.py --config-file ./configuration/XXX.yaml ``` -export CPUs=`lscpu | grep 'Core(s) per socket' | awk '{print $4}'` -export OMP_NUM_THREADS=${CPUs} -export KMP_AFFINITY=granularity=fine,noduplicates,compact,1,0 - -# dummy data -python test_recognizer.py --model inceptionv3_ucf101 --use-pretrained --mode hybrid --input-size 299 --new-height 340 --new-width 450 --num-segments 3 --batch-size 64 --benchmark - -# real data -python test_recognizer.py --model inceptionv3_ucf101 --use-pretrained --mode hybrid --input-size 299 --new-height 340 --new-width 450 --num-segments 3 --batch-size 64 +If multi-grid training is needed, +``` +python train_ddp_shortonly_pytorch.py --config-file ./configuration/XXX.yaml ``` +Note that we only use short-cycle here because it is stable and applies to a range of models. -### Calibration -In naive mode, FP32 models are calibrated by using 5 mini-batches of data (32 images per batch). Quantized models will be saved into `./model/`. 
+### [How to evaluate?](https://cv.gluon.ai/build/examples_torch_action_recognition/demo_i3d_kinetics400.html) ``` -# ucf101 dataset -python test_recognizer.py --model inceptionv3_ucf101 --new-height 340 --new-width 450 --input-size 299 --num-segments 3 --use-pretrained --calibration +# Change PRETRAINED to True if using our pretraind model zoo +python test_ddp_pytorch.py --config-file ./configuration/XXX.yaml +``` + +### [How to extract features?](https://cv.gluon.ai/build/examples_torch_action_recognition/extract_feat.html) -# kinetics400 dataset -python test_recognizer.py --dataset kinetics400 --data-dir path/to/datasets --model resnet18_v1b_kinetics400 --use-pretrained --num-classes 400 --new-height 256 --new-width 340 --input-size 224 --num-segments 7 --calibration +``` +python feat_extract_pytorch.py --config-file ./configuration/XXX.yaml ``` -### INT8 Inference +### [How to get speed measurement?](https://cv.gluon.ai/build/examples_torch_action_recognition/speed.html) ``` +python get_flops.py --config-file ./configuration/XXX.yaml +python get_fps.py --config-file ./configuration/XXX.yaml +``` -export CPUs=`lscpu | grep 'Core(s) per socket' | awk '{print $4}'` -export OMP_NUM_THREADS=${CPUs} -export KMP_AFFINITY=granularity=fine,noduplicates,compact,1,0 -# dummy data -python test_recognizer.py --model inceptionv3_ucf101 --mode hybrid --input-size 299 --new-height 340 --new-width 450 --batch-size 64 --num-segments 3 --quantized --benchmark +## MXNet Tutorial -# real data -python test_recognizer.py --model inceptionv3_ucf101 --mode hybrid --input-size 299 --new-height 340 --new-width 450 --batch-size 64 --num-segments 3 --quantized -# deploy static model -python test_recognizer.py --model inceptionv3_ucf101 --deploy --model-prefix ./model/inceptionv3_ucf101-quantized-naive --input-size 299 --new-height 340 --new-width 450 --batch-size 64 --num-segments 3 --benchmark +### [How to train?](https://cv.gluon.ai/build/examples_action_recognition/dive_deep_i3d_kinetics400.html) +MXNet codebase adopts argparser, hence requiring many arguments. Please check [model zoo page](https://cv.gluon.ai/model_zoo/action_recognition.html) for detailed training command. +``` +python train_recognizer.py ``` -Users are also recommended to bind processes to specific cores via `numactl` for better performance, like below: +### [How to evaluate?](https://cv.gluon.ai/build/examples_action_recognition/demo_i3d_kinetics400.html) ``` -numactl --physcpubind=0-27 --membind=0 python test_recognizer.py ... +python test_recognizer.py ``` -## Performance -Below results are collected based on Intel(R) VNNI enabled C5.12xlarge with 24 physical cores. +### [How to extract features?](https://cv.gluon.ai/build/examples_action_recognition/feat_custom.html) + +``` +python feat_extract.py +``` -|model | fp32 Top-1 | int8 Top-1 | -|-- | -- | -- | -inceptionv3_ucf101 |86.92 | 86.55 | -vgg16_ucf101 |81.86 | 81.41 | -resnet18_v1b_kinetics400 |63.29 | 63.14 | -resnet50_v1b_kinetics400 |68.08 | 68.15 | -inceptionv3_kinetics400 |67.93 | 67.92 | +### [How to do inference on your own video?](https://cv.gluon.ai/build/examples_action_recognition/demo_custom.html) + +``` +python inference.py +``` -## References +## MXNet calibration +Please check out [CALIBRATION.md](https://raw.githubusercontent.com/dmlc/gluon-cv/master/scripts/action-recognition/CALIBRATION.md) for more information on INT8 model calibration and inference. -1. Limin Wang, Yuanjun Xiong, Zhe Wang and Yu Qiao. 
“Towards Good Practices for Very Deep Two-Stream ConvNets.” arXiv preprint arXiv:1507.02159, 2015. +## Reproducing our arXiv survey paper +Please check out [ARXIV.md](https://raw.githubusercontent.com/dmlc/gluon-cv/master/scripts/action-recognition/ARXIV.md) for more information on how to get the same dataset and how to reproduce all the methods in our model zoo. diff --git a/scripts/action-recognition/configuration/i3d_resnet50_v1_kinetics400.yaml b/scripts/action-recognition/configuration/i3d_resnet50_v1_kinetics400.yaml index f651332908..cfb10c9b1b 100644 --- a/scripts/action-recognition/configuration/i3d_resnet50_v1_kinetics400.yaml +++ b/scripts/action-recognition/configuration/i3d_resnet50_v1_kinetics400.yaml @@ -45,7 +45,7 @@ CONFIG: MODEL: NAME: 'i3d_resnet50_v1_kinetics400' - PRETRAINED: False + PRETRAINED: True LOG: BASE_PATH: './logs/i3d_resnet50_v1_kinetics400'
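Side note on the configuration change above: with `PRETRAINED: True`, `get_model(cfg)` loads the model-zoo weights, which the updated demo relies on. If you prefer not to edit the yaml by hand, the flag can also be flipped from code through the standard yacs override mechanism. The sketch below assumes the `get_cfg_defaults`/`get_model` helpers used in the demo script; the import paths are an assumption and may need adjusting.

```python
# Minimal sketch (import paths assumed; match them to the demo script's actual imports).
from gluoncv.torch.engine.config import get_cfg_defaults
from gluoncv.torch.model_zoo import get_model

cfg = get_cfg_defaults()
cfg.merge_from_file('./scripts/action-recognition/configuration/i3d_resnet50_v1_kinetics400.yaml')
# Override the flag programmatically instead of editing the yaml (yacs merge_from_list).
cfg.merge_from_list(['CONFIG.MODEL.PRETRAINED', True])

model = get_model(cfg)
model.eval()  # the demo runs the pretrained model in inference mode
```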