Generalized OHEM (#54)

* Generalized OHEM

* remove config

* update docstring

* fixed sort prob

* fixed valid_mask
pull/1801/head
Jerry Jiarui XU 2020-08-09 23:49:23 +08:00 committed by GitHub
parent fd100e02c4
commit 99e3e5e499
5 changed files with 62 additions and 46 deletions

View File

@ -271,36 +271,23 @@ Usually it is slow if you do not have high speed networking like InfiniBand.
### Launch multiple jobs on a single machine
If you launch multiple jobs on a single machine, e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs,
you need to specify different ports (29500 by default) for each job to avoid communication conflict.
you need to specify different ports (29500 by default) for each job to avoid communication conflict. Otherwise, there will be error message saying `RuntimeError: Address already in use`.
If you use `dist_train.sh` to launch training jobs, you can set the port in commands.
If you use `dist_train.sh` to launch training jobs, you can set the port in commands with environment variable `PORT`.
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4
CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4
```
If you use launch training jobs with Slurm, you need to modify the config files (usually the 6th line from the bottom in config files) to set different communication ports.
If you use `slurm_train.sh` to launch training jobs, you can set the port in commands with environment variable `MASTER_PORT`.
In `config1.py`,
```python
dist_params = dict(backend='nccl', port=29500)
```
In `config2.py`,
```python
dist_params = dict(backend='nccl', port=29501)
```
Then you can launch two jobs with `config1.py` ang `config2.py`.
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR}
CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR}
MASTER_PORT=29500 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE}
MASTER_PORT=29501 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE}
```
Or you could specify port by `---options dist_params.port=29501`
## Useful tools
We provide lots of useful tools under `tools/` directory.

View File

@ -25,7 +25,7 @@ model=dict(
decode_head=dict(
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)) )
```
In this way, only pixels with confidence score under 0.7 are used to train. And we keep at least 100000 pixels during training.
In this way, only pixels with confidence score under 0.7 are used to train. And we keep at least 100000 pixels during training. If `thresh` is not specified, pixels of top ``min_kept`` loss will be selected.
## Class Balanced Loss
For dataset that is not balanced in classes distribution, you may change the loss weight of each class.

View File

@ -10,22 +10,25 @@ class OHEMPixelSampler(BasePixelSampler):
"""Online Hard Example Mining Sampler for segmentation.
Args:
thresh (float): The threshold for hard example selection. Below
which, are prediction with low confidence. Default: 0.7.
min_kept (int): The minimum number of predictions to keep.
context (nn.Module): The context of sampler, subclass of
:obj:`BaseDecodeHead`.
thresh (float, optional): The threshold for hard example selection.
Below which, are prediction with low confidence. If not
specified, the hard examples will be pixels of top ``min_kept``
loss. Default: None.
min_kept (int, optional): The minimum number of predictions to keep.
Default: 100000.
ignore_index (int): The ignore index for training. Default: 255.
"""
def __init__(self, thresh=0.7, min_kept=100000, ignore_index=255):
def __init__(self, context, thresh=None, min_kept=100000):
super(OHEMPixelSampler, self).__init__()
self.context = context
assert min_kept > 1
self.thresh = thresh
self.min_kept = min_kept
self.ignore_index = ignore_index
def sample(self, seg_logit, seg_label):
"""
"""Sample pixels that have high loss or with low prediction confidence.
Args:
seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
@ -33,32 +36,41 @@ class OHEMPixelSampler(BasePixelSampler):
Returns:
torch.Tensor: segmentation weight, shape (N, H, W)
"""
with torch.no_grad():
assert seg_logit.shape[2:] == seg_label.shape[2:]
assert seg_label.shape[1] == 1
seg_label = seg_label.squeeze(1).long()
batch_kept = self.min_kept * seg_label.size(0)
seg_prob = F.softmax(seg_logit, dim=1)
mask = seg_label.contiguous().view(-1, ) != self.ignore_index
valid_mask = seg_label != self.context.ignore_index
seg_weight = seg_logit.new_zeros(size=seg_label.size())
valid_seg_weight = seg_weight[valid_mask]
if self.thresh is not None:
seg_prob = F.softmax(seg_logit, dim=1)
tmp_seg_label = seg_label.clone()
tmp_seg_label[tmp_seg_label == self.ignore_index] = 0
seg_prob = seg_prob.gather(1, tmp_seg_label.unsqueeze(1))
sort_prob, sort_indices = seg_prob.contiguous().view(
-1, )[mask].contiguous().sort()
tmp_seg_label = seg_label.clone().unsqueeze(1)
tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
sort_prob, sort_indices = seg_prob[valid_mask].sort()
if sort_prob.numel() > 0:
min_threshold = sort_prob[min(batch_kept,
sort_prob.numel() - 1)]
if sort_prob.numel() > 0:
min_threshold = sort_prob[min(batch_kept,
sort_prob.numel() - 1)]
else:
min_threshold = 0.0
threshold = max(min_threshold, self.thresh)
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
else:
min_threshold = 0.0
threshold = max(min_threshold, self.thresh)
losses = self.context.loss_decode(
seg_logit,
seg_label,
weight=None,
ignore_index=self.context.ignore_index,
reduction_override='none')
# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
_, sort_indices = losses[valid_mask].sort(descending=True)
valid_seg_weight[sort_indices[:batch_kept]] = 1.
seg_weight = seg_logit.new_ones(size=seg_label.size())
seg_weight = seg_weight.view(-1)
seg_weight[mask][sort_prob < threshold] = 0.
seg_weight = seg_weight.view_as(seg_label)
seg_weight[valid_mask] = valid_seg_weight
return seg_weight

View File

@ -73,7 +73,7 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
self.ignore_index = ignore_index
self.align_corners = align_corners
if sampler is not None:
self.sampler = build_pixel_sampler(sampler)
self.sampler = build_pixel_sampler(sampler, context=self)
else:
self.sampler = None

View File

@ -2,20 +2,37 @@ import pytest
import torch
from mmseg.core import OHEMPixelSampler
from mmseg.models.decode_heads import FCNHead
def _context_for_ohem():
return FCNHead(in_channels=32, channels=16, num_classes=19)
def test_ohem_sampler():
with pytest.raises(AssertionError):
# seg_logit and seg_label must be of the same size
sampler = OHEMPixelSampler()
sampler = OHEMPixelSampler(context=_context_for_ohem())
seg_logit = torch.randn(1, 19, 45, 45)
seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
sampler.sample(seg_logit, seg_label)
sampler = OHEMPixelSampler()
# test with thresh
sampler = OHEMPixelSampler(
context=_context_for_ohem(), thresh=0.7, min_kept=200)
seg_logit = torch.randn(1, 19, 45, 45)
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
seg_weight = sampler.sample(seg_logit, seg_label)
assert seg_weight.shape[0] == seg_logit.shape[0]
assert seg_weight.shape[1:] == seg_logit.shape[2:]
assert seg_weight.sum() > 200
# test w.o thresh
sampler = OHEMPixelSampler(context=_context_for_ohem(), min_kept=200)
seg_logit = torch.randn(1, 19, 45, 45)
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
seg_weight = sampler.sample(seg_logit, seg_label)
assert seg_weight.shape[0] == seg_logit.shape[0]
assert seg_weight.shape[1:] == seg_logit.shape[2:]
assert seg_weight.sum() == 200