pytorch multigrid updates (#1620)
* pytorch multigrid updates

(1) Modify the config logic for multigrid usage: the long cycle and short cycle can now be enabled independently, and users can define the long-cycle schedule themselves.
(2) Rename multiGridSampler to multiGridHelper.
(3) Re-implement the computation of batch length in the multigrid scenario.

* tiny bug fix

* tiny bug fix

Co-authored-by: Chunhui Liu <[email protected]>
ECHO960 and Chunhui Liu authored Feb 26, 2021
1 parent 5ec39fa commit ec76b7d
Showing 5 changed files with 106 additions and 72 deletions.
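
Before the per-file diffs, here is a minimal sketch of how a training script would drive the new switches. The key names (`USE_LONG_CYCLE`, `USE_SHORT_CYCLE`, `LONG_CYCLE_EPOCH`) and the `get_cfg_defaults` entry point come from the diffs below; everything else is illustrative:

```python
from gluoncv.torch.engine.config import get_cfg_defaults

cfg = get_cfg_defaults()

# New in this commit: the long and short cycles toggle independently,
# replacing the single CONFIG.DATA.MULTIGRID flag.
cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE = True
cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE = True
# User-defined long-cycle schedule: epochs at which the grid steps.
cfg.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH = [10, 20, 30]

# build_dataloader treats multigrid as "on" if either cycle is enabled.
use_multigrid = (cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE
                 or cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE)
```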
2 changes: 1 addition & 1 deletion gluoncv/torch/data/__init__.py
@@ -4,5 +4,5 @@

from .video_cls.dataset_classification import VideoClsDataset
from .video_cls.dataset_classification import build_dataloader, build_dataloader_test
from .video_cls.multigrid_helper import multiGridSampler, MultiGridBatchSampler
from .video_cls.multigrid_helper import multiGridHelper, MultiGridBatchSampler
from .coot.dataloader import create_datasets, create_loaders
23 changes: 13 additions & 10 deletions gluoncv/torch/data/video_cls/dataset_classification.py
@@ -8,7 +8,7 @@
from torch.utils.data import Dataset

from ..transforms.videotransforms import video_transforms, volume_transforms
from .multigrid_helper import multiGridSampler, MultiGridBatchSampler
from .multigrid_helper import multiGridHelper, MultiGridBatchSampler


__all__ = ['VideoClsDataset', 'build_dataloader', 'build_dataloader_test']
@@ -45,12 +45,12 @@ def __init__(self, anno_path, data_path, mode='train', clip_len=8,

if (mode == 'train'):
if self.use_multigrid:
self.MG_sampler = multiGridSampler()
self.mg_helper = multiGridHelper()
self.data_transform = []
for alpha in range(self.MG_sampler.mod_long):
for alpha in range(self.mg_helper.mod_long):
tmp = []
for beta in range(self.MG_sampler.mod_short):
info = self.MG_sampler.get_resize(alpha, beta)
for beta in range(self.mg_helper.mod_short):
info = self.mg_helper.get_resize(alpha, beta)
scale_s = info[1]
tmp.append(video_transforms.Compose([
video_transforms.Resize(int(self.short_side_size / scale_s),
@@ -108,7 +108,7 @@ def __getitem__(self, index):
if self.mode == 'train':
if self.use_multigrid is True:
index, alpha, beta = index
info = self.MG_sampler.get_resize(alpha, beta)
info = self.mg_helper.get_resize(alpha, beta)
scale_t = info[0]
data_transform_func = self.data_transform[alpha][beta]
else:
@@ -241,7 +241,8 @@ def build_dataloader(cfg):
train_dataset = VideoClsDataset(anno_path=cfg.CONFIG.DATA.TRAIN_ANNO_PATH,
data_path=cfg.CONFIG.DATA.TRAIN_DATA_PATH,
mode='train',
use_multigrid=cfg.CONFIG.DATA.MULTIGRID,
use_multigrid=cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE \
or cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE,
clip_len=cfg.CONFIG.DATA.CLIP_LEN,
frame_sample_rate=cfg.CONFIG.DATA.FRAME_RATE,
num_segment=cfg.CONFIG.DATA.NUM_SEGMENT,
@@ -254,7 +255,7 @@ def build_dataloader(cfg):
val_dataset = VideoClsDataset(anno_path=cfg.CONFIG.DATA.VAL_ANNO_PATH,
data_path=cfg.CONFIG.DATA.VAL_DATA_PATH,
mode='validation',
use_multigrid=cfg.CONFIG.DATA.MULTIGRID,
use_multigrid=False,
clip_len=cfg.CONFIG.DATA.CLIP_LEN,
frame_sample_rate=cfg.CONFIG.DATA.FRAME_RATE,
num_segment=cfg.CONFIG.DATA.NUM_SEGMENT,
@@ -273,9 +274,11 @@ def build_dataloader(cfg):
val_sampler = None

mg_sampler = None
if cfg.CONFIG.DATA.MULTIGRID:
if cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE or cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE:
mg_sampler = MultiGridBatchSampler(train_sampler, batch_size=cfg.CONFIG.TRAIN.BATCH_SIZE,
drop_last=True)
drop_last=True,
use_long=cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE,
use_short=cfg.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False,
num_workers=9, pin_memory=True,
batch_sampler=mg_sampler)
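
A side note on the loader construction above: PyTorch's `DataLoader` requires `batch_size` to stay at its default of 1 whenever a `batch_sampler` is supplied, because the batch sampler alone decides how many samples go into each batch. That is exactly what lets `MultiGridBatchSampler` grow and shrink batches over the short cycle. Each element it yields is a `[idx, alpha, beta]` triple, which `__getitem__` unpacks to pick the matching transform. A brief sketch of the consuming side (illustrative only):

```python
# Assumes mg_sampler was built by build_dataloader(cfg) as above.
# Each batch is a list of [idx, alpha, beta]; alpha indexes the long
# cycle, beta the short cycle. Batch length varies with (alpha, beta).
for batch in mg_sampler:
    idx, alpha, beta = batch[0]
    print(f"size={len(batch)} alpha={alpha} beta={beta}")
    break  # one batch is enough for illustration
```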
141 changes: 82 additions & 59 deletions gluoncv/torch/data/video_cls/multigrid_helper.py
@@ -6,17 +6,18 @@
from torch._six import int_classes as _int_classes


__all__ = ['multiGridSampler', 'MultiGridBatchSampler']
__all__ = ['multiGridHelper', 'MultiGridBatchSampler']


sq2 = np.sqrt(2)
class multiGridSampler(object):
class multiGridHelper(object):
"""
A Multigrid Method for Efficiently Training Video Models
Chao-Yuan Wu, Ross Girshick, Kaiming He, Christoph Feichtenhofer, Philipp Krähenbühl
CVPR 2020, https://arxiv.org/abs/1912.00998
"""
def __init__(self):
# Scale: [T, H, W]
self.long_cycle = np.asarray([[1, 1, 1], [2, 1, 1], [2, sq2, sq2]])[::-1]
self.short_cycle = np.asarray([[1, 1, 1], [1, sq2, sq2], [1, 2, 2]])[::-1]
self.short_cycle_sp = np.asarray([[1, 1, 1], [1, sq2, sq2], [1, sq2, sq2]])[::-1]
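
The rows of `long_cycle` and `short_cycle` are per-axis divisors for the `[T, H, W]` shape annotated above, flipped so the coarsest setting comes first. The helper's `get_scale` is not shown in this hunk, so the sketch below derives a batch-size multiplier under my own assumption (following the multigrid paper): shrinking a clip by some factor per axis frees room for proportionally more clips per batch.

```python
import numpy as np

sq2 = np.sqrt(2)
# Per-axis divisors for (T, H, W); [::-1] puts the coarsest grid first.
long_cycle = np.asarray([[1, 1, 1], [2, 1, 1], [2, sq2, sq2]])[::-1]
short_cycle = np.asarray([[1, 1, 1], [1, sq2, sq2], [1, 2, 2]])[::-1]

def batch_multiplier(alpha, beta):
    """Assumed rule: a clip scaled down by (t, h, w) costs t*h*w times
    less memory, so t*h*w times more clips fit in one batch."""
    scale = long_cycle[alpha] * short_cycle[beta]
    return int(round(float(np.prod(scale))))

print(batch_multiplier(0, 0))  # coarsest long+short grid -> 16x base batch
print(batch_multiplier(2, 2))  # finest grid -> 1x (the single-grid setting)
```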
@@ -54,7 +55,15 @@ class MultiGridBatchSampler(Sampler):
Chao-Yuan Wu, Ross Girshick, Kaiming He, Christoph Feichtenhofer, Philipp Krähenbühl
CVPR 2020, https://arxiv.org/abs/1912.00998
"""
def __init__(self, sampler, batch_size, drop_last):
def __init__(self, sampler, batch_size, drop_last, use_long=False, use_short=False):
'''
:param sampler: torch.utils.data.Sampler
:param batch_size: int
:param drop_last: bool
:param use_long: bool, enable the long cycle
:param use_short: bool, enable the short cycle
Collect batches whose sizes follow the multiGridHelper schedule.
'''
if not isinstance(sampler, Sampler):
raise ValueError("sampler should be an instance of "
"torch.utils.data.Sampler, but got sampler={}"
@@ -66,85 +75,99 @@ def __init__(self, sampler, batch_size, drop_last):
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
if not isinstance(use_long, bool):
raise ValueError("use_long should be a boolean value, but got "
"use_long={}".format(use_long))
if not isinstance(use_short, bool):
raise ValueError("use_short should be a boolean value, but got "
"use_short={}".format(use_short))
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last

self.MG_sampler = multiGridSampler()
self.alpha = self.MG_sampler.mod_long - 1
self.mg_helper = multiGridHelper()
# single grid setting
self.alpha = self.mg_helper.mod_long - 1
self.beta = self.mg_helper.mod_short - 1
self.short_cycle_label = False
if use_long:
self.activate_long_cycle()
if use_short:
self.activate_short_cycle()
self.batch_scale = self.mg_helper.get_scale(self.alpha, self.beta)

def activate_short_cycle(self):
self.short_cycle_label = True
self.beta = 0
self.batch_scale = self.MG_sampler.get_scale(self.alpha, self.beta)
self.label = True

def deactivate(self):
self.label = False
self.alpha = self.MG_sampler.mod_long - 1

def activate(self):
self.label = True
def activate_long_cycle(self):
self.alpha = 0

def deactivate(self):
self.alpha = self.mg_helper.mod_long - 1
self.beta = self.mg_helper.mod_short - 1
self.short_cycle_label = False

def __iter__(self):
batch = []
if self.label:
if self.short_cycle_label:
self.beta = 0
else:
self.beta = self.MG_sampler.mod_short - 1
self.batch_scale = self.MG_sampler.get_scale(self.alpha, self.beta)
self.beta = self.mg_helper.mod_short - 1
self.batch_scale = self.mg_helper.get_scale(self.alpha, self.beta)
for idx in self.sampler:
batch.append([idx, self.alpha, self.beta])
if len(batch) == self.batch_size*self.batch_scale:
yield batch
batch = []
if self.label:
self.beta = (self.beta + 1)%self.MG_sampler.mod_short
self.batch_scale = self.MG_sampler.get_scale(self.alpha, self.beta)
if self.short_cycle_label:
self.beta = (self.beta + 1) % self.mg_helper.mod_short
self.batch_scale = self.mg_helper.get_scale(self.alpha, self.beta)

if len(batch) > 0 and not self.drop_last:
yield batch

def step_alpha(self):
self.alpha = (self.alpha + 1)%self.MG_sampler.mod_long

def compute_lr_milestone(self, lr_milestone):
"""
long cycle milestones
"""
self.len_long = self.MG_sampler.mod_long
self.n_epoch_long = 0
for x in range(self.len_long):
self.n_epoch_long += self.MG_sampler.get_scale_alpha(x)
lr_long_cycle = []
for i, _ in enumerate(lr_milestone):
if i == 0:
pre = 0
else:
pre = lr_milestone[i-1]
cycle_length = (lr_milestone[i] - pre) // self.n_epoch_long
bonus = (lr_milestone[i] - pre)%self.n_epoch_long // self.len_long
for j in range(self.len_long)[::-1]:
pre = pre + cycle_length*(2**j) + bonus
if j == 0:
pre = lr_milestone[i]
lr_long_cycle.append(pre)
lr_long_cycle.append(0)
lr_long_cycle = sorted(lr_long_cycle)
return lr_long_cycle
def step_long_cycle(self):
self.alpha = (self.alpha + 1) % self.mg_helper.mod_long

# def compute_lr_milestone(self, lr_milestone):
# """
# long cycle milestones, deprecated. Define long cycle in config files
# """
# self.len_long = self.mg_helper.mod_long
# self.n_epoch_long = 0
# for x in range(self.len_long):
# self.n_epoch_long += self.mg_helper.get_scale_alpha(x)
# lr_long_cycle = []
# for i, _ in enumerate(lr_milestone):
# if i == 0:
# pre = 0
# else:
# pre = lr_milestone[i-1]
# cycle_length = (lr_milestone[i] - pre) // self.n_epoch_long
# bonus = (lr_milestone[i] - pre)%self.n_epoch_long // self.len_long
# for j in range(self.len_long)[::-1]:
# pre = pre + cycle_length*(2**j) + bonus
# if j == 0:
# pre = lr_milestone[i]
# lr_long_cycle.append(pre)
# lr_long_cycle.append(0)
# lr_long_cycle = sorted(lr_long_cycle)
# return lr_long_cycle

def __len__(self):
self.len_short = self.MG_sampler.mod_short
self.n_epoch_short = 0
for x in range(self.len_short):
self.n_epoch_short += self.MG_sampler.get_scale_beta(x)
short_batch_size = self.batch_size * self.MG_sampler.get_scale_alpha(self.alpha)
num_short = len(self.sampler) // short_batch_size

total = num_short // self.n_epoch_short * self.len_short
remain = self.n_epoch_short
for x in range(self.len_short):
remain = remain - (2**x)
scale_per_short_cycle = 0
for x in range(self.mg_helper.mod_short):
scale_per_short_cycle += self.mg_helper.get_scale(self.alpha, x)
num_full_short_cycle = len(self.sampler) // (self.batch_size * scale_per_short_cycle)

total = num_full_short_cycle * self.mg_helper.mod_short
remain = len(self.sampler) % (self.batch_size * scale_per_short_cycle)
for x in range(self.mg_helper.mod_short):
remain = remain - self.mg_helper.get_scale(self.alpha, x)*self.batch_size
if remain >= 0 or (remain < 0 and self.drop_last is False):
total += 1
if remain <= 0:
break
total = total + int(num_short%self.n_epoch_short >= remain)

assert remain <= 0
return total
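
To exercise the re-implemented `__len__` and the new flags end to end, here is a hedged usage sketch. The constructor signature, the re-exported names, and the `[idx, alpha, beta]` batch format are all taken from this commit; the concrete sizes printed depend on `get_scale`, whose body is not shown in the hunks above.

```python
from torch.utils.data.sampler import SequentialSampler
from gluoncv.torch.data import MultiGridBatchSampler

sampler = SequentialSampler(range(1000))  # any index sampler works
mg_sampler = MultiGridBatchSampler(sampler, batch_size=8, drop_last=True,
                                   use_long=True, use_short=True)

# __len__ now sums the varying batch sizes across one short cycle.
print(len(mg_sampler))

for i, batch in enumerate(mg_sampler):
    idx, alpha, beta = batch[0]
    print(f"batch {i}: size={len(batch)}, alpha={alpha}, beta={beta}")
    if i == 2:
        break
```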
7 changes: 5 additions & 2 deletions gluoncv/torch/engine/config.py
@@ -68,6 +68,11 @@
# Resume training from a specific epoch. Set to -1 means train from beginning.
_C.CONFIG.TRAIN.RESUME_EPOCH = -1

# Whether to use multigrid training to speed up.
_C.CONFIG.TRAIN.MULTIGRID = CN(new_allowed=True)
_C.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE = False
_C.CONFIG.TRAIN.MULTIGRID.USE_SHORT_CYCLE = False
_C.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH = [10, 20, 30]

_C.CONFIG.VAL = CN(new_allowed=True)
# Evaluate model on test data every eval period epochs.
@@ -91,8 +96,6 @@
_C.CONFIG.DATA.VAL_DATA_PATH = ''
# The number of classes to predict for the model.
_C.CONFIG.DATA.NUM_CLASSES = 400
# Whether to use multigrid training to speed up.
_C.CONFIG.DATA.MULTIGRID = False
# The number of frames of the input clip.
_C.CONFIG.DATA.CLIP_LEN = 16
# The video sampling rate of the input clip.
5 changes: 5 additions & 0 deletions scripts/action-recognition/train_ddp_shortonly_pytorch.py
@@ -14,6 +14,7 @@
from gluoncv.torch.engine.config import get_cfg_defaults
from gluoncv.torch.engine.launch import spawn_workers
from gluoncv.torch.utils.utils import build_log_dir
from gluoncv.torch.utils.lr_policy import GradualWarmupScheduler


def main_worker(cfg):
@@ -67,6 +68,10 @@ def main_worker(cfg):
else:
scheduler.step()

if cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE:
if epoch in cfg.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH:
mg_sampler.step_long_cycle()

if epoch % cfg.CONFIG.VAL.FREQ == 0 or epoch == cfg.CONFIG.TRAIN.EPOCH_NUM - 1:
validation_classification(model, val_loader, epoch, criterion, cfg, writer)

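The script change above is the whole user-facing contract for the long cycle: at every epoch listed in `LONG_CYCLE_EPOCH`, step the batch sampler once. Condensed, with the surrounding training code elided:

```python
# cfg and mg_sampler are assumed to come from get_cfg_defaults() /
# build_dataloader(cfg), as in the surrounding script.
for epoch in range(cfg.CONFIG.TRAIN.EPOCH_NUM):
    # ... one epoch of training plus LR scheduling elided ...
    if cfg.CONFIG.TRAIN.MULTIGRID.USE_LONG_CYCLE and \
            epoch in cfg.CONFIG.TRAIN.MULTIGRID.LONG_CYCLE_EPOCH:
        mg_sampler.step_long_cycle()  # advance alpha to the next grid
```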
