Generalized OHEM (#54)

* Generalized OHEM

* remove config

* update docstring

* fixed sort prob

* fixed valid_mask
Jerry Jiarui XU 2020-08-09 23:49:23 +08:00 committed by GitHub
parent fd100e02c4
commit 99e3e5e499
5 changed files with 62 additions and 46 deletions


@@ -271,36 +271,23 @@ Usually it is slow if you do not have high speed networking like InfiniBand.
 ### Launch multiple jobs on a single machine
 If you launch multiple jobs on a single machine, e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs,
-you need to specify different ports (29500 by default) for each job to avoid communication conflict.
+you need to specify different ports (29500 by default) for each job to avoid communication conflict. Otherwise, there will be error message saying `RuntimeError: Address already in use`.
-If you use `dist_train.sh` to launch training jobs, you can set the port in commands.
+If you use `dist_train.sh` to launch training jobs, you can set the port in commands with environment variable `PORT`.
 ```shell
 CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4
 CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4
 ```
-If you use launch training jobs with Slurm, you need to modify the config files (usually the 6th line from the bottom in config files) to set different communication ports.
+If you use `slurm_train.sh` to launch training jobs, you can set the port in commands with environment variable `MASTER_PORT`.
-In `config1.py`,
-```python
-dist_params = dict(backend='nccl', port=29500)
-```
-In `config2.py`,
-```python
-dist_params = dict(backend='nccl', port=29501)
-```
-Then you can launch two jobs with `config1.py` ang `config2.py`.
 ```shell
-CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR}
-CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR}
+MASTER_PORT=29500 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE}
+MASTER_PORT=29501 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE}
 ```
-Or you could specify port by `---options dist_params.port=29501`
 ## Useful tools
 We provide lots of useful tools under `tools/` directory.


@@ -25,7 +25,7 @@ model=dict(
     decode_head=dict(
         sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)) )
 ```
-In this way, only pixels with confidence score under 0.7 are used to train. And we keep at least 100000 pixels during training.
+In this way, only pixels with confidence score under 0.7 are used to train. And we keep at least 100000 pixels during training. If `thresh` is not specified, pixels of top ``min_kept`` loss will be selected.
 ## Class Balanced Loss
 For dataset that is not balanced in classes distribution, you may change the loss weight of each class.
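For context, a minimal config sketch of the two OHEM modes the updated tutorial describes. The sampler keys follow the snippet in the diff above; the loss-based variant added by this PR simply omits `thresh`, and the surrounding `model`/`decode_head` nesting mirrors the documented example rather than a complete config.

```python
# Sketch of the two OHEM sampler modes from the tutorial diff above.

# 1) Confidence-threshold OHEM: pixels whose predicted probability for the
#    ground-truth class falls below `thresh` are kept (at least `min_kept`).
ohem_by_confidence = dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)

# 2) Loss-based OHEM (the generalization in this PR): omit `thresh`, and the
#    `min_kept` pixels with the highest decode-head loss are selected instead.
ohem_by_loss = dict(type='OHEMPixelSampler', min_kept=100000)

model = dict(
    decode_head=dict(
        sampler=ohem_by_loss))
```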


@@ -10,22 +10,25 @@ class OHEMPixelSampler(BasePixelSampler):
     """Online Hard Example Mining Sampler for segmentation.
 
     Args:
-        thresh (float): The threshold for hard example selection. Below
-            which, are prediction with low confidence. Default: 0.7.
-        min_kept (int): The minimum number of predictions to keep.
+        context (nn.Module): The context of sampler, subclass of
+            :obj:`BaseDecodeHead`.
+        thresh (float, optional): The threshold for hard example selection.
+            Below which, are prediction with low confidence. If not
+            specified, the hard examples will be pixels of top ``min_kept``
+            loss. Default: None.
+        min_kept (int, optional): The minimum number of predictions to keep.
             Default: 100000.
-        ignore_index (int): The ignore index for training. Default: 255.
     """
 
-    def __init__(self, thresh=0.7, min_kept=100000, ignore_index=255):
+    def __init__(self, context, thresh=None, min_kept=100000):
         super(OHEMPixelSampler, self).__init__()
+        self.context = context
         assert min_kept > 1
         self.thresh = thresh
         self.min_kept = min_kept
-        self.ignore_index = ignore_index
 
     def sample(self, seg_logit, seg_label):
-        """
+        """Sample pixels that have high loss or with low prediction confidence.
 
         Args:
             seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
@@ -33,21 +36,22 @@ class OHEMPixelSampler(BasePixelSampler):
         Returns:
             torch.Tensor: segmentation weight, shape (N, H, W)
         """
         with torch.no_grad():
             assert seg_logit.shape[2:] == seg_label.shape[2:]
             assert seg_label.shape[1] == 1
             seg_label = seg_label.squeeze(1).long()
             batch_kept = self.min_kept * seg_label.size(0)
-            seg_prob = F.softmax(seg_logit, dim=1)
-            mask = seg_label.contiguous().view(-1, ) != self.ignore_index
-            tmp_seg_label = seg_label.clone()
-            tmp_seg_label[tmp_seg_label == self.ignore_index] = 0
-            seg_prob = seg_prob.gather(1, tmp_seg_label.unsqueeze(1))
-            sort_prob, sort_indices = seg_prob.contiguous().view(
-                -1, )[mask].contiguous().sort()
-            if sort_prob.numel() > 0:
-                min_threshold = sort_prob[min(batch_kept,
+            valid_mask = seg_label != self.context.ignore_index
+            seg_weight = seg_logit.new_zeros(size=seg_label.size())
+            valid_seg_weight = seg_weight[valid_mask]
+            if self.thresh is not None:
+                seg_prob = F.softmax(seg_logit, dim=1)
+                tmp_seg_label = seg_label.clone().unsqueeze(1)
+                tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
+                seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
+                sort_prob, sort_indices = seg_prob[valid_mask].sort()
+                if sort_prob.numel() > 0:
+                    min_threshold = sort_prob[min(batch_kept,
@@ -55,10 +59,18 @@ class OHEMPixelSampler(BasePixelSampler):
-            else:
-                min_threshold = 0.0
-            threshold = max(min_threshold, self.thresh)
-            seg_weight = seg_logit.new_ones(size=seg_label.size())
-            seg_weight = seg_weight.view(-1)
-            seg_weight[mask][sort_prob < threshold] = 0.
-            seg_weight = seg_weight.view_as(seg_label)
+                else:
+                    min_threshold = 0.0
+                threshold = max(min_threshold, self.thresh)
+                valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
+            else:
+                losses = self.context.loss_decode(
+                    seg_logit,
+                    seg_label,
+                    weight=None,
+                    ignore_index=self.context.ignore_index,
+                    reduction_override='none')
+                # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
+                _, sort_indices = losses[valid_mask].sort(descending=True)
+                valid_seg_weight[sort_indices[:batch_kept]] = 1.
+            seg_weight[valid_mask] = valid_seg_weight
             return seg_weight
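To make the new selection path easier to follow, here is a self-contained toy sketch of the loss-based branch (the `else` block above): among the valid pixels it keeps the `batch_kept` highest-loss ones. The tensor shapes and values are made up for illustration; the real sampler gets `losses` from `self.context.loss_decode` and the ignore index from the head.

```python
import torch

# Toy illustration of the loss-based OHEM branch added in this PR:
# keep the `batch_kept` valid pixels with the highest per-pixel loss.
losses = torch.rand(2, 4, 4)                 # per-pixel loss, shape (N, H, W)
seg_label = torch.randint(0, 19, (2, 4, 4))  # dummy labels, no ignored pixels here
valid_mask = seg_label != 255                # 255 plays the role of ignore_index
batch_kept = 5

seg_weight = torch.zeros(2, 4, 4)
valid_seg_weight = seg_weight[valid_mask]    # flat copy of the valid entries
# sort + slice instead of topk (see the PyTorch issue linked in the diff)
_, sort_indices = losses[valid_mask].sort(descending=True)
valid_seg_weight[sort_indices[:batch_kept]] = 1.
seg_weight[valid_mask] = valid_seg_weight    # scatter the weights back

assert seg_weight.sum() == batch_kept        # exactly batch_kept pixels selected
```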


@@ -73,7 +73,7 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
         self.ignore_index = ignore_index
         self.align_corners = align_corners
         if sampler is not None:
-            self.sampler = build_pixel_sampler(sampler)
+            self.sampler = build_pixel_sampler(sampler, context=self)
         else:
             self.sampler = None
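This one-line change is what makes the generalized sampler work: the sampler no longer stores its own `ignore_index` but reads `ignore_index` and `loss_decode` from the head it is attached to, so the head passes itself as `context`. A hedged sketch of the resulting wiring, assuming the FCNHead constructor forwards the sampler config to `BaseDecodeHead` as the test below does:

```python
from mmseg.models.decode_heads import FCNHead

# Building a decode head with an OHEM sampler config now wires the head itself
# into the sampler as `context` (via build_pixel_sampler(sampler, context=self)).
head = FCNHead(
    in_channels=32,
    channels=16,
    num_classes=19,
    sampler=dict(type='OHEMPixelSampler', min_kept=100000))  # thresh omitted

assert head.sampler.context is head   # sampler reuses head.ignore_index / head.loss_decode
assert head.sampler.thresh is None    # loss-based OHEM mode
```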


@@ -2,20 +2,37 @@ import pytest
 import torch
 
 from mmseg.core import OHEMPixelSampler
+from mmseg.models.decode_heads import FCNHead
+
+
+def _context_for_ohem():
+    return FCNHead(in_channels=32, channels=16, num_classes=19)
 
 
 def test_ohem_sampler():
     with pytest.raises(AssertionError):
         # seg_logit and seg_label must be of the same size
-        sampler = OHEMPixelSampler()
+        sampler = OHEMPixelSampler(context=_context_for_ohem())
         seg_logit = torch.randn(1, 19, 45, 45)
         seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
         sampler.sample(seg_logit, seg_label)
 
-    sampler = OHEMPixelSampler()
+    # test with thresh
+    sampler = OHEMPixelSampler(
+        context=_context_for_ohem(), thresh=0.7, min_kept=200)
     seg_logit = torch.randn(1, 19, 45, 45)
     seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
     seg_weight = sampler.sample(seg_logit, seg_label)
     assert seg_weight.shape[0] == seg_logit.shape[0]
     assert seg_weight.shape[1:] == seg_logit.shape[2:]
+    assert seg_weight.sum() > 200
+
+    # test w.o thresh
+    sampler = OHEMPixelSampler(context=_context_for_ohem(), min_kept=200)
+    seg_logit = torch.randn(1, 19, 45, 45)
+    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
+    seg_weight = sampler.sample(seg_logit, seg_label)
+    assert seg_weight.shape[0] == seg_logit.shape[0]
+    assert seg_weight.shape[1:] == seg_logit.shape[2:]
+    assert seg_weight.sum() == 200