update distributed training
YuxinZou committed Sep 4, 2020
1 parent 5aaf95b commit c3925f9
Showing 24 changed files with 591 additions and 283 deletions.
29 changes: 22 additions & 7 deletions README.md
@@ -11,14 +11,19 @@ vedaseg is an open source semantic segmentation toolbox based on PyTorch.

The toolbox supports several popular semantic segmentation frameworks out of the box, *e.g.* DeepLabv3+, DeepLabv3, U-Net, PSPNet, FPN, etc.

- - **Deployment and acceleration**
+ - **Support of deployment and acceleration**

-   The toolbox can automatically transform and accelerate PyTorch, Onnx and Tensorflow models with TensorRT, can also automatically generate benchmark with given model.
+   The toolbox can accelerate models using TensorRT and can also benchmark them.

- - **Different training modes**
+ - **Support of multiple train/test modes**

-   The toolbox supports both single-label training and multi-label training.
+   The toolbox supports both distributed and non-distributed modes.
+
+ - **Support of multiple train/test tasks**
+
+   The toolbox supports both single-label and multi-label tasks.


## License

This project is released under the [Apache 2.0 license](LICENSE).
@@ -158,10 +163,15 @@ data
Modify the configuration as needed in a config file such as `configs/voc_unet.py`.
* For multi-label training, use the config file `configs/coco_multilabel_unet.py` and adjust it as needed. The differences between single-label and multi-label training lie mainly in the following config parameters: `nclasses`, `multi_label`, `metrics` and `criterion` (see the sketch below). Currently, multi-label training is supported only in the COCO data format.
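  As a rough illustration, the multi-label variants of those parameters might look like the following; the loss/metric type names here are hypothetical, not taken from the repository:

  ```python
  # Illustrative sketch only -- type names are hypothetical, values are examples.
  nclasses = 80          # number of classes in the COCO-format dataset
  multi_label = True     # per-pixel targets become multi-hot vectors

  criterion = dict(type='BCEWithLogitsLoss')  # per-class binary loss instead of cross-entropy
  metrics = [
      dict(type='IoU', num_classes=nclasses),  # evaluated per channel after thresholding
  ]
  ```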

- 2. Run
+ 2. Non-distributed training

```shell
- python tools/trainval.py configs/voc_unet.py
+ python tools/train.py configs/voc_unet.py
```

+ 3. Distributed training
+ ```shell
+ ./tools/dist_train.sh configs/voc_unet.py gpu_num
+ ```
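The `dist_params=dict(backend='nccl')` entry added to one of the config diffs below hints at what `dist_train.sh` sets up. A minimal sketch of the initialization a distributed entry point typically performs (an assumption about the general pattern, not the repository's actual code):

```python
import os

import torch
import torch.distributed as dist

def init_distributed(backend='nccl'):
    # One process per GPU: torch.distributed.launch provides RANK, WORLD_SIZE,
    # MASTER_ADDR and MASTER_PORT through the environment (and the local rank
    # either as an env var or a --local_rank argument, depending on flags).
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=backend)  # default init_method is 'env://'
    return local_rank
```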

Snapshots and logs will be generated at `${vedaseg_root}/workdir`.
@@ -172,12 +182,17 @@ Snapshots and logs will be generated at `${vedaseg_root}/workdir`.

Modify the configuration as needed in a config file such as `configs/voc_unet.py`.

- 2. Run
+ 2. Non-distributed testing

```shell
python tools/test.py configs/voc_unet.py checkpoint_path
```

+ 3. Distributed testing
+ ```shell
+ ./tools/dist_test.sh configs/voc_unet.py checkpoint_path gpu_num
+ ```
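For example, to evaluate a snapshot from the default work directory on 4 GPUs (both the checkpoint path and the GPU count here are only illustrative): `./tools/dist_test.sh configs/voc_unet.py workdir/voc_unet/epoch_50.pth 4`.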

## Inference

1. Config
33 changes: 22 additions & 11 deletions configs/coco_multilabel_unet.py
@@ -9,6 +9,7 @@
img_norm_cfg = dict(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
max_pixel_value=255.0)
+ norm_cfg = dict(type='BN')
multi_label = True

inference = dict(
@@ -29,6 +30,7 @@
type='ResNet',
arch='resnet101',
pretrain=True,
+ norm_cfg=norm_cfg,
),
),
# model/decoder
@@ -58,7 +60,7 @@
out_channels=256,
kernel_size=3,
padding=1,
- norm_cfg=dict(type='BN'),
+ norm_cfg=norm_cfg,
act_cfg=dict(type='Relu', inplace=True),
num_convs=2,
),
@@ -86,7 +88,7 @@
out_channels=128,
kernel_size=3,
padding=1,
- norm_cfg=dict(type='BN'),
+ norm_cfg=norm_cfg,
act_cfg=dict(type='Relu', inplace=True),
num_convs=2,
),
@@ -114,7 +116,7 @@
out_channels=64,
kernel_size=3,
padding=1,
- norm_cfg=dict(type='BN'),
+ norm_cfg=norm_cfg,
act_cfg=dict(type='Relu', inplace=True),
num_convs=2,
),
@@ -141,7 +143,7 @@
out_channels=32,
kernel_size=3,
padding=1,
- norm_cfg=dict(type='BN'),
+ norm_cfg=norm_cfg,
act_cfg=dict(type='Relu', inplace=True),
num_convs=2,
),
@@ -167,7 +169,7 @@
out_channels=16,
kernel_size=3,
padding=1,
- norm_cfg=dict(type='BN'),
+ norm_cfg=norm_cfg,
act_cfg=dict(type='Relu', inplace=True),
num_convs=2,
),
@@ -222,10 +224,13 @@
multi_label=multi_label,
),
transforms=inference['transforms'],
+ sampler=dict(
+ type='DistributedSampler',
+ ),
dataloader=dict(
type='DataLoader',
- batch_size=8,
- num_workers=4,
+ samples_per_gpu=4,
+ workers_per_gpu=4,
shuffle=False,
drop_last=False,
pin_memory=True,
@@ -266,10 +271,13 @@
dict(type='Normalize', **img_norm_cfg),
dict(type='ToTensor'),
],
+ sampler=dict(
+ type='DistributedSampler',
+ ),
dataloader=dict(
type='DataLoader',
- batch_size=16,
- num_workers=4,
+ samples_per_gpu=8,
+ workers_per_gpu=4,
shuffle=True,
drop_last=True,
pin_memory=True,
@@ -284,10 +292,13 @@
multi_label=multi_label,
),
transforms=inference['transforms'],
+ sampler=dict(
+ type='DistributedSampler',
+ ),
dataloader=dict(
type='DataLoader',
- batch_size=8,
- num_workers=4,
+ samples_per_gpu=8,
+ workers_per_gpu=4,
shuffle=False,
drop_last=False,
pin_memory=True,
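The `sampler` and `samples_per_gpu`/`workers_per_gpu` entries above suggest how the dataloader is likely assembled per process. A hedged sketch of the common pattern (`build_dataloader` is an illustrative helper, not vedaseg's actual builder):

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_dataloader(dataset, samples_per_gpu, workers_per_gpu,
                     shuffle=False, drop_last=False, pin_memory=True,
                     distributed=True):
    # In distributed mode every process loads a distinct shard of the dataset,
    # so batch size and worker count are per GPU rather than global.
    sampler = DistributedSampler(dataset, shuffle=shuffle) if distributed else None
    return DataLoader(
        dataset,
        batch_size=samples_per_gpu,
        num_workers=workers_per_gpu,
        sampler=sampler,
        shuffle=shuffle if sampler is None else False,  # sampler replaces shuffle
        drop_last=drop_last,
        pin_memory=pin_memory,
    )
```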
34 changes: 23 additions & 11 deletions configs/coco_unet.py
@@ -9,6 +9,7 @@
img_norm_cfg = dict(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
max_pixel_value=255.0)
+ norm_cfg = dict(type='BN')
multi_label = False

inference = dict(
@@ -29,6 +30,7 @@
type='ResNet',
arch='resnet101',
pretrain=True,
+ norm_cfg=norm_cfg,
),
),
# model/decoder
@@ -58,7 +60,7 @@
out_channels=256,
kernel_size=3,
padding=1,
- norm_cfg=dict(type='BN'),
+ norm_cfg=norm_cfg,
act_cfg=dict(type='Relu', inplace=True),
num_convs=2,
),
@@ -86,7 +88,7 @@
out_channels=128,
kernel_size=3,
padding=1,
- norm_cfg=dict(type='BN'),
+ norm_cfg=norm_cfg,
act_cfg=dict(type='Relu', inplace=True),
num_convs=2,
),
@@ -114,7 +116,7 @@
out_channels=64,
kernel_size=3,
padding=1,
- norm_cfg=dict(type='BN'),
+ norm_cfg=norm_cfg,
act_cfg=dict(type='Relu', inplace=True),
num_convs=2,
),
@@ -141,7 +143,7 @@
out_channels=32,
kernel_size=3,
padding=1,
- norm_cfg=dict(type='BN'),
+ norm_cfg=norm_cfg,
act_cfg=dict(type='Relu', inplace=True),
num_convs=2,
),
@@ -167,7 +169,7 @@
out_channels=16,
kernel_size=3,
padding=1,
- norm_cfg=dict(type='BN'),
+ norm_cfg=norm_cfg,
act_cfg=dict(type='Relu', inplace=True),
num_convs=2,
),
@@ -179,6 +181,7 @@
type='Head',
in_channels=16,
out_channels=nclasses,
+ norm_cfg=norm_cfg,
num_convs=0,
upsample=dict(
type='Upsample',
@@ -222,10 +225,13 @@
multi_label=multi_label,
),
transforms=inference['transforms'],
+ sampler=dict(
+ type='DistributedSampler',
+ ),
dataloader=dict(
type='DataLoader',
- batch_size=8,
- num_workers=4,
+ samples_per_gpu=4,
+ workers_per_gpu=4,
shuffle=False,
drop_last=False,
pin_memory=True,
@@ -266,10 +272,13 @@
dict(type='Normalize', **img_norm_cfg),
dict(type='ToTensor'),
],
+ sampler=dict(
+ type='DistributedSampler',
+ ),
dataloader=dict(
type='DataLoader',
- batch_size=16,
- num_workers=4,
+ samples_per_gpu=8,
+ workers_per_gpu=4,
shuffle=True,
drop_last=True,
pin_memory=True,
@@ -284,10 +293,13 @@
multi_label=multi_label,
),
transforms=inference['transforms'],
+ sampler=dict(
+ type='DistributedSampler',
+ ),
dataloader=dict(
type='DataLoader',
- batch_size=8,
- num_workers=4,
+ samples_per_gpu=8,
+ workers_per_gpu=4,
shuffle=False,
drop_last=False,
pin_memory=True,
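Note the semantics change in these hunks: the old `batch_size` was a single dataloader's batch, while `samples_per_gpu` is per process, so the effective global batch becomes `samples_per_gpu × gpu_num` (e.g. `samples_per_gpu=8` on 4 GPUs yields a global batch of 32), which may warrant scaling the learning rate accordingly.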
26 changes: 20 additions & 6 deletions configs/voc_deeplabv3.py
@@ -9,6 +9,7 @@
img_norm_cfg = dict(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
max_pixel_value=255.0)
+ norm_cfg = dict(type='BN')
multi_label = False

inference = dict(
@@ -28,6 +29,7 @@
arch='resnet101',
replace_stride_with_dilation=[False, False, True],
multi_grid=[1, 2, 4],
+ norm_cfg=norm_cfg,
),
enhance=dict(
type='ASPP',
@@ -38,6 +40,7 @@
atrous_rates=[6, 12, 18],
mode='bilinear',
align_corners=True,
+ norm_cfg=norm_cfg,
dropout=0.1,
),
),
@@ -48,6 +51,7 @@
in_channels=256,
inter_channels=256,
out_channels=nclasses,
+ norm_cfg=norm_cfg,
num_convs=1,
upsample=dict(
type='Upsample',
@@ -78,6 +82,7 @@
dict(type='IoU', num_classes=nclasses),
dict(type='MIoU', num_classes=nclasses, average='equal'),
],
+ dist_params=dict(backend='nccl'),
)

## 2.1 configuration for test
@@ -90,10 +95,13 @@
multi_label=multi_label,
),
transforms=inference['transforms'],
+ sampler=dict(
+ type='DistributedSampler',
+ ),
dataloader=dict(
type='DataLoader',
- batch_size=16,
- num_workers=4,
+ samples_per_gpu=4,
+ workers_per_gpu=4,
shuffle=False,
drop_last=False,
pin_memory=True,
@@ -128,10 +136,13 @@
dict(type='Normalize', **img_norm_cfg),
dict(type='ToTensor'),
],
+ sampler=dict(
+ type='DistributedSampler',
+ ),
dataloader=dict(
type='DataLoader',
- batch_size=16,
- num_workers=4,
+ samples_per_gpu=8,
+ workers_per_gpu=4,
shuffle=True,
drop_last=True,
pin_memory=True,
@@ -145,10 +156,13 @@
multi_label=multi_label,
),
transforms=inference['transforms'],
+ sampler=dict(
+ type='DistributedSampler',
+ ),
dataloader=dict(
type='DataLoader',
- batch_size=16,
- num_workers=4,
+ samples_per_gpu=8,
+ workers_per_gpu=4,
shuffle=False,
drop_last=False,
pin_memory=True,
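A plausible motivation for threading a single `norm_cfg` through every module (an inference from the diff, not stated in the commit): ordinary BatchNorm only sees per-GPU statistics under distributed training, and with a shared `norm_cfg` switching the whole model to synchronized BatchNorm becomes a one-line change, assuming the builder registers such a type:

```python
# Hypothetical: swap every BN in the model for synchronized BatchNorm,
# assuming the config system registers a 'SyncBN' type for distributed runs.
norm_cfg = dict(type='SyncBN')
```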
(diff for the remaining 20 of the 24 changed files not shown)