Generalized OHEM (#54)
* Generalized OHEM * remove config * update docstring * fixed sort prob * fixed valid_maskpull/1801/head
parent
fd100e02c4
commit
99e3e5e499
|
@ -271,36 +271,23 @@ Usually it is slow if you do not have high speed networking like InfiniBand.
|
|||
### Launch multiple jobs on a single machine
|
||||
|
||||
If you launch multiple jobs on a single machine, e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs,
|
||||
you need to specify different ports (29500 by default) for each job to avoid communication conflict.
|
||||
you need to specify different ports (29500 by default) for each job to avoid communication conflict. Otherwise, there will be an error message saying `RuntimeError: Address already in use`.
|
||||
|
||||
If you use `dist_train.sh` to launch training jobs, you can set the port in commands.
|
||||
If you use `dist_train.sh` to launch training jobs, you can set the port in the command with the environment variable `PORT`.
|
||||
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4
|
||||
```
|
||||
|
||||
If you launch training jobs with Slurm, you need to modify the config files (usually the 6th line from the bottom in config files) to set different communication ports.
|
||||
If you use `slurm_train.sh` to launch training jobs, you can set the port in the command with the environment variable `MASTER_PORT`.
|
||||
|
||||
In `config1.py`,
|
||||
```python
|
||||
dist_params = dict(backend='nccl', port=29500)
|
||||
```
|
||||
|
||||
In `config2.py`,
|
||||
```python
|
||||
dist_params = dict(backend='nccl', port=29501)
|
||||
```
|
||||
|
||||
Then you can launch two jobs with `config1.py` and `config2.py`.
|
||||
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR}
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR}
|
||||
MASTER_PORT=29500 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE}
|
||||
MASTER_PORT=29501 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE}
|
||||
```
|
||||
|
||||
Or you could specify the port with `--options dist_params.port=29501`.
|
||||
|
||||
## Useful tools
|
||||
|
||||
We provide lots of useful tools under `tools/` directory.
|
||||
|
|
|
@ -25,7 +25,7 @@ model=dict(
|
|||
decode_head=dict(
|
||||
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)) )
|
||||
```
|
||||
In this way, only pixels with confidence score under 0.7 are used to train. And we keep at least 100000 pixels during training.
|
||||
In this way, only pixels with confidence score under 0.7 are used to train. And we keep at least 100000 pixels during training. If `thresh` is not specified, the `min_kept` pixels with the highest loss will be selected.
|
||||
|
||||
## Class Balanced Loss
|
||||
For a dataset whose class distribution is not balanced, you may change the loss weight of each class.
|
||||
|
|
|
@ -10,22 +10,25 @@ class OHEMPixelSampler(BasePixelSampler):
|
|||
"""Online Hard Example Mining Sampler for segmentation.
|
||||
|
||||
Args:
|
||||
thresh (float): The threshold for hard example selection. Below
|
||||
which, are prediction with low confidence. Default: 0.7.
|
||||
min_kept (int): The minimum number of predictions to keep.
|
||||
context (nn.Module): The context of sampler, subclass of
|
||||
:obj:`BaseDecodeHead`.
|
||||
thresh (float, optional): The threshold for hard example selection.
|
||||
Below which, are prediction with low confidence. If not
|
||||
specified, the hard examples will be pixels of top ``min_kept``
|
||||
loss. Default: None.
|
||||
min_kept (int, optional): The minimum number of predictions to keep.
|
||||
Default: 100000.
|
||||
ignore_index (int): The ignore index for training. Default: 255.
|
||||
"""
|
||||
|
||||
def __init__(self, thresh=0.7, min_kept=100000, ignore_index=255):
|
||||
def __init__(self, context, thresh=None, min_kept=100000):
|
||||
super(OHEMPixelSampler, self).__init__()
|
||||
self.context = context
|
||||
assert min_kept > 1
|
||||
self.thresh = thresh
|
||||
self.min_kept = min_kept
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
def sample(self, seg_logit, seg_label):
|
||||
"""
|
||||
"""Sample pixels that have high loss or with low prediction confidence.
|
||||
|
||||
Args:
|
||||
seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
|
||||
|
@ -33,32 +36,41 @@ class OHEMPixelSampler(BasePixelSampler):
|
|||
|
||||
Returns:
|
||||
torch.Tensor: segmentation weight, shape (N, H, W)
|
||||
|
||||
"""
|
||||
with torch.no_grad():
|
||||
assert seg_logit.shape[2:] == seg_label.shape[2:]
|
||||
assert seg_label.shape[1] == 1
|
||||
seg_label = seg_label.squeeze(1).long()
|
||||
batch_kept = self.min_kept * seg_label.size(0)
|
||||
seg_prob = F.softmax(seg_logit, dim=1)
|
||||
mask = seg_label.contiguous().view(-1, ) != self.ignore_index
|
||||
valid_mask = seg_label != self.context.ignore_index
|
||||
seg_weight = seg_logit.new_zeros(size=seg_label.size())
|
||||
valid_seg_weight = seg_weight[valid_mask]
|
||||
if self.thresh is not None:
|
||||
seg_prob = F.softmax(seg_logit, dim=1)
|
||||
|
||||
tmp_seg_label = seg_label.clone()
|
||||
tmp_seg_label[tmp_seg_label == self.ignore_index] = 0
|
||||
seg_prob = seg_prob.gather(1, tmp_seg_label.unsqueeze(1))
|
||||
sort_prob, sort_indices = seg_prob.contiguous().view(
|
||||
-1, )[mask].contiguous().sort()
|
||||
tmp_seg_label = seg_label.clone().unsqueeze(1)
|
||||
tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
|
||||
seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
|
||||
sort_prob, sort_indices = seg_prob[valid_mask].sort()
|
||||
|
||||
if sort_prob.numel() > 0:
|
||||
min_threshold = sort_prob[min(batch_kept,
|
||||
sort_prob.numel() - 1)]
|
||||
if sort_prob.numel() > 0:
|
||||
min_threshold = sort_prob[min(batch_kept,
|
||||
sort_prob.numel() - 1)]
|
||||
else:
|
||||
min_threshold = 0.0
|
||||
threshold = max(min_threshold, self.thresh)
|
||||
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
|
||||
else:
|
||||
min_threshold = 0.0
|
||||
threshold = max(min_threshold, self.thresh)
|
||||
losses = self.context.loss_decode(
|
||||
seg_logit,
|
||||
seg_label,
|
||||
weight=None,
|
||||
ignore_index=self.context.ignore_index,
|
||||
reduction_override='none')
|
||||
# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
|
||||
_, sort_indices = losses[valid_mask].sort(descending=True)
|
||||
valid_seg_weight[sort_indices[:batch_kept]] = 1.
|
||||
|
||||
seg_weight = seg_logit.new_ones(size=seg_label.size())
|
||||
seg_weight = seg_weight.view(-1)
|
||||
seg_weight[mask][sort_prob < threshold] = 0.
|
||||
seg_weight = seg_weight.view_as(seg_label)
|
||||
seg_weight[valid_mask] = valid_seg_weight
|
||||
|
||||
return seg_weight
|
||||
|
|
|
@ -73,7 +73,7 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
|
|||
self.ignore_index = ignore_index
|
||||
self.align_corners = align_corners
|
||||
if sampler is not None:
|
||||
self.sampler = build_pixel_sampler(sampler)
|
||||
self.sampler = build_pixel_sampler(sampler, context=self)
|
||||
else:
|
||||
self.sampler = None
|
||||
|
||||
|
|
|
@ -2,20 +2,37 @@ import pytest
|
|||
import torch
|
||||
|
||||
from mmseg.core import OHEMPixelSampler
|
||||
from mmseg.models.decode_heads import FCNHead
|
||||
|
||||
|
||||
def _context_for_ohem():
|
||||
return FCNHead(in_channels=32, channels=16, num_classes=19)
|
||||
|
||||
|
||||
def test_ohem_sampler():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# seg_logit and seg_label must be of the same size
|
||||
sampler = OHEMPixelSampler()
|
||||
sampler = OHEMPixelSampler(context=_context_for_ohem())
|
||||
seg_logit = torch.randn(1, 19, 45, 45)
|
||||
seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
|
||||
sampler.sample(seg_logit, seg_label)
|
||||
|
||||
sampler = OHEMPixelSampler()
|
||||
# test with thresh
|
||||
sampler = OHEMPixelSampler(
|
||||
context=_context_for_ohem(), thresh=0.7, min_kept=200)
|
||||
seg_logit = torch.randn(1, 19, 45, 45)
|
||||
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
|
||||
seg_weight = sampler.sample(seg_logit, seg_label)
|
||||
assert seg_weight.shape[0] == seg_logit.shape[0]
|
||||
assert seg_weight.shape[1:] == seg_logit.shape[2:]
|
||||
assert seg_weight.sum() > 200
|
||||
|
||||
# test w.o thresh
|
||||
sampler = OHEMPixelSampler(context=_context_for_ohem(), min_kept=200)
|
||||
seg_logit = torch.randn(1, 19, 45, 45)
|
||||
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
|
||||
seg_weight = sampler.sample(seg_logit, seg_label)
|
||||
assert seg_weight.shape[0] == seg_logit.shape[0]
|
||||
assert seg_weight.shape[1:] == seg_logit.shape[2:]
|
||||
assert seg_weight.sum() == 200
|
||||
|
|
Loading…
Reference in New Issue