Skip to content

Commit

Permalink
add default sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
YuxinZou committed Sep 14, 2020
1 parent c3925f9 commit 12f6d31
Show file tree
Hide file tree
Showing 15 changed files with 85 additions and 51 deletions.
18 changes: 8 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,14 @@ data
Modify some configuration accordingly in the config file like `configs/voc_unet.py`
* for multi-label training use config file `configs/coco_multilabel_unet.py` and modify some configuration, the difference between single-label and multi-label training are mainly in following parameter in config file: `nclasses`, `multi_label`, `metrics` and `criterion`. Currently multi-label training is only supported in coco data format.

2. Non-distributed training

2. Ditributed training
```shell
python tools/train.py configs/voc_unet.py
./tools/dist_train.sh configs/voc_unet.py gpu_num
```

3. Ditributed training
3. Non-distributed training
```shell
./tools/dist_train.sh configs/voc_unet.py gpu_num
python tools/train.py configs/voc_unet.py
```

Snapshots and logs will be generated at `${vedaseg_root}/workdir`.
Expand All @@ -182,15 +181,14 @@ Snapshots and logs will be generated at `${vedaseg_root}/workdir`.

Modify some configuration accordingly in the config file like `configs/voc_unet.py`

2. Non-distributed testing

2. Ditributed testing
```shell
python tools/test.py configs/voc_unet.py checkpoint_path
./tools/dist_test.sh configs/voc_unet.py checkpoint_path gpu_num
```

3. Ditributed testing
3. Non-distributed testing
```shell
./tools/dist_test.sh configs/voc_unet.py checkpoint_path gpu_num
python tools/test.py configs/voc_unet.py checkpoint_path
```

## Inference
Expand Down
6 changes: 3 additions & 3 deletions configs/coco_multilabel_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down Expand Up @@ -272,7 +272,7 @@
dict(type='ToTensor'),
],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand All @@ -293,7 +293,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down
10 changes: 5 additions & 5 deletions configs/coco_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down Expand Up @@ -273,7 +273,7 @@
dict(type='ToTensor'),
],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand All @@ -288,13 +288,13 @@
dataset=dict(
type=dataset_type,
root=dataset_root,
ann_file='instances_val2014.json',
img_prefix='val2014',
ann_file='instances_val2017.json',
img_prefix='val2017',
multi_label=multi_label,
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down
6 changes: 3 additions & 3 deletions configs/voc_deeplabv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down Expand Up @@ -137,7 +137,7 @@
dict(type='ToTensor'),
],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand All @@ -157,7 +157,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down
6 changes: 3 additions & 3 deletions configs/voc_deeplabv3plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down Expand Up @@ -169,7 +169,7 @@
dict(type='ToTensor'),
],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand All @@ -189,7 +189,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down
6 changes: 3 additions & 3 deletions configs/voc_fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down Expand Up @@ -274,7 +274,7 @@
dict(type='ToTensor'),
],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand All @@ -294,7 +294,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down
6 changes: 3 additions & 3 deletions configs/voc_pspnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down Expand Up @@ -140,7 +140,7 @@
dict(type='ToTensor'),
],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand All @@ -160,7 +160,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down
6 changes: 3 additions & 3 deletions configs/voc_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down Expand Up @@ -269,7 +269,7 @@
dict(type='ToTensor'),
],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand All @@ -289,7 +289,7 @@
),
transforms=inference['transforms'],
sampler=dict(
type='DistributedSampler',
type='DefaultSampler',
),
dataloader=dict(
type='DataLoader',
Expand Down
5 changes: 1 addition & 4 deletions vedaseg/dataloaders/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,17 @@
def build_dataloader(distributed, num_gpus, cfg, default_args=None):
cfg_ = cfg.copy()

shuffle = cfg_.pop('shuffle')
samples_per_gpu = cfg_.pop('samples_per_gpu')
workers_per_gpu = cfg_.pop('workers_per_gpu')

if distributed:
shuffle = False
batch_size = samples_per_gpu
num_workers = workers_per_gpu
else:
batch_size = num_gpus * samples_per_gpu
num_workers = num_gpus * workers_per_gpu

cfg_.update({'shuffle': shuffle,
'batch_size': batch_size,
cfg_.update({'batch_size': batch_size,
'num_workers': num_workers})

dataloader = build_from_cfg(cfg_, DATALOADERS, default_args)
Expand Down
2 changes: 2 additions & 0 deletions vedaseg/dataloaders/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .builder import build_sampler
from .distributed import DefaultSampler
from .non_distributed import DefaultSampler
9 changes: 6 additions & 3 deletions vedaseg/dataloaders/samplers/builder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from ...utils import build_from_cfg
from .registry import SAMPLERS
from .registry import NON_DISTRIBUTED_SAMPLERS, DISTRIBUTED_SAMPLERS


def build_sampler(cfg, default_args=None):
sampler = build_from_cfg(cfg, SAMPLERS, default_args)
def build_sampler(distributed, cfg, default_args=None):
if distributed:
sampler = build_from_cfg(cfg, DISTRIBUTED_SAMPLERS, default_args)
else:
sampler = build_from_cfg(cfg, NON_DISTRIBUTED_SAMPLERS, default_args)

return sampler
13 changes: 13 additions & 0 deletions vedaseg/dataloaders/samplers/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from torch.utils.data import DistributedSampler

from ...utils import get_dist_info
from .registry import DISTRIBUTED_SAMPLERS


@DISTRIBUTED_SAMPLERS.register_module
class DefaultSampler(DistributedSampler):
"""Default distributed sampler."""

def __init__(self, dataset, shuffle=True):
rank, num_replicas = get_dist_info()
super().__init__(dataset, num_replicas, rank, shuffle)
22 changes: 22 additions & 0 deletions vedaseg/dataloaders/samplers/non_distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch
from torch.utils.data import Sampler

from .registry import NON_DISTRIBUTED_SAMPLERS


@NON_DISTRIBUTED_SAMPLERS.register_module
class DefaultSampler(Sampler):
"""Default non-distributed sampler."""

def __init__(self, dataset, shuffle=True):
self.dataset = dataset
self.shuffle = shuffle

def __iter__(self):
if self.shuffle:
return iter(torch.randperm(len(self.dataset)).tolist())
else:
return iter(range(len(self.dataset)))

def __len__(self):
return len(self.dataset)
7 changes: 2 additions & 5 deletions vedaseg/dataloaders/samplers/registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from torch.utils.data import DistributedSampler

from ...utils import Registry

SAMPLERS = Registry('sampler')

SAMPLERS.register_module(DistributedSampler)
DISTRIBUTED_SAMPLERS = Registry('distributed_sampler')
NON_DISTRIBUTED_SAMPLERS = Registry('non_distributed_sampler')
14 changes: 8 additions & 6 deletions vedaseg/runner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,16 @@ def _build_dataloader(self, cfg):
transform = build_transform(cfg['transforms'])
dataset = build_dataset(cfg['dataset'], dict(transform=transform))

shuffle = cfg['dataloader'].get('shuffle', False)
sampler = build_sampler(cfg['sampler'], dict(dataset=dataset,
shuffle=shuffle)) if cfg.get(
'sampler') is not None else None
shuffle = cfg['dataloader'].pop('shuffle', False)
sampler = build_sampler(self.distribute,
cfg['sampler'],
dict(dataset=dataset,
shuffle=shuffle))

dataloader = build_dataloader(self.distribute,
self.gpu_num,
cfg['dataloader'], dict(dataset=dataset,
sampler=sampler))
cfg['dataloader'],
dict(dataset=dataset,
sampler=sampler))

return dataloader

0 comments on commit 12f6d31

Please sign in to comment.