[Feature] Add cutmix option (#198)

* Add cutmix option

* fix code style

* add some annotations

* add annotation about custom_hooks

* check constraint of alpha > 0

* add test cutmix

* fix code style

* add cutmix to configs/models

* add cutmix to configs/resnet

* flake8

* empty
whcao 2021-04-14 21:27:42 +08:00, committed by GitHub
commit 1cde6f6e65 (parent b7b520881f)
9 changed files with 178 additions and 8 deletions
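CutMix (Yun et al., ICCV 2019) cuts a random rectangular patch out of each training image, pastes in the patch from another image of the same batch, and mixes the one-hot labels in proportion to the pasted area. This commit adds a BatchCutMixLayer that applies the operation batch-wise inside ImageClassifier.forward_train, plus base model configs that enable it for ResNet-50 on CIFAR-10 and ImageNet.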

View File: configs/_base_/default_runtime.py

@@ -8,6 +8,9 @@ log_config = dict(
         # dict(type='TensorboardLoggerHook')
     ])
 # yapf:enable
+
+# You can register your own hooks like this
+# custom_hooks=[dict(type='EMAHook')]
 dist_params = dict(backend='nccl')
 log_level = 'INFO'
 load_from = None
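To make the new comment concrete, here is a minimal sketch of a downstream config that enables such a hook. This is an assumption-laden example, not part of the commit: it presumes the training entry point reads and registers custom_hooks (as the comment implies) and that an 'EMAHook' exists in the hook registry (mmcv ships one).

# Hypothetical user config enabling an EMA hook via custom_hooks.
# 'EMAHook' must be registered with the runner's hook registry
# (mmcv provides an EMAHook implementation).
_base_ = ['../_base_/default_runtime.py']  # path relative to this config

custom_hooks = [dict(type='EMAHook')]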

View File: configs/_base_/models/resnet50_cifar_cutmix.py

@@ -0,0 +1,16 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet_CIFAR',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiLabelLinearClsHead',
        num_classes=10,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
    train_cfg=dict(cutmix=dict(alpha=1.0, num_classes=10, cutmix_prob=1.0)))
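Note that the cutmix model configs pair the backbone with MultiLabelLinearClsHead and CrossEntropyLoss(use_soft=True): after CutMix the target is no longer a class index but a mixed (N, C) one-hot vector, so the head has to accept soft labels. (cutmix_prob is included here because BatchCutMixLayer's constructor, shown further down, requires it.)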

View File: configs/_base_/models/resnet50_cutmix.py

@@ -0,0 +1,16 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiLabelLinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
    train_cfg=dict(cutmix=dict(alpha=1.0, num_classes=1000, cutmix_prob=1.0)))

View File

@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/resnet50_cutmix.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
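A quick way to sanity-check a config like this is to build the classifier and run one training forward pass. A sketch follows; the config filename is illustrative, since this page does not show where the new file lands under configs/resnet/.

# Build the model from the config; train_cfg (and thus cutmix) is nested
# inside cfg.model, so build_classifier picks it up automatically.
import torch
from mmcv import Config

from mmcls.models import build_classifier

# Illustrative path; the actual filename is not shown on this page.
cfg = Config.fromfile('configs/resnet/resnet50_cutmix_imagenet.py')
model = build_classifier(cfg.model)

imgs = torch.randn(2, 3, 224, 224)
labels = torch.randint(0, 1000, (2, ))
losses = model.forward_train(imgs, labels)  # CutMix is applied inside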

View File: mmcls/models/classifiers/image.py

@@ -1,7 +1,7 @@
 import torch.nn as nn
 
 from ..builder import CLASSIFIERS, build_backbone, build_head, build_neck
-from ..utils import BatchMixupLayer
+from ..utils import BatchCutMixLayer, BatchMixupLayer
 from .base import BaseClassifier
@@ -24,10 +24,16 @@ class ImageClassifier(BaseClassifier):
         if head is not None:
             self.head = build_head(head)
 
-        self.mixup = None
+        self.mixup, self.cutmix = None, None
         if train_cfg is not None:
             mixup_cfg = train_cfg.get('mixup', None)
-            self.mixup = BatchMixupLayer(**mixup_cfg)
+            cutmix_cfg = train_cfg.get('cutmix', None)
+            assert mixup_cfg is None or cutmix_cfg is None, \
+                'Mixup and CutMix can not be set simultaneously.'
+            if mixup_cfg is not None:
+                self.mixup = BatchMixupLayer(**mixup_cfg)
+            if cutmix_cfg is not None:
+                self.cutmix = BatchCutMixLayer(**cutmix_cfg)
 
         self.init_weights(pretrained=pretrained)
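The assert above makes mixup and cutmix mutually exclusive, so a config that sets both now fails fast at model construction instead of composing the two augmentations in an unspecified order; existing mixup-only configs are unaffected.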
@@ -56,18 +62,19 @@ class ImageClassifier(BaseClassifier):
         Args:
             img (Tensor): of shape (N, C, H, W) encoding input images.
                 Typically these should be mean centered and std scaled.
             gt_label (Tensor): Ground-truth labels of the input images. It
                 should be of shape (N, 1) for single-label tasks and of
                 shape (N, C) for multi-label tasks.
 
         Returns:
             dict[str, Tensor]: a dictionary of loss components
         """
         if self.mixup is not None:
             img, gt_label = self.mixup(img, gt_label)
 
+        if self.cutmix is not None:
+            img, gt_label = self.cutmix(img, gt_label)
+
         x = self.extract_feat(img)
 
         losses = dict()

View File: mmcls/models/utils/__init__.py

@@ -1,4 +1,5 @@
 from .channel_shuffle import channel_shuffle
+from .cutmix import BatchCutMixLayer
 from .inverted_residual import InvertedResidual
 from .make_divisible import make_divisible
 from .mixup import BatchMixupLayer
@@ -6,5 +7,5 @@ from .se_layer import SELayer
 
 __all__ = [
     'channel_shuffle', 'make_divisible', 'InvertedResidual', 'BatchMixupLayer',
-    'SELayer'
+    'BatchCutMixLayer', 'SELayer'
 ]

View File: mmcls/models/utils/cutmix.py

@@ -0,0 +1,80 @@
from abc import ABCMeta, abstractmethod

import numpy as np
import torch
import torch.nn.functional as F


class BaseCutMixLayer(object, metaclass=ABCMeta):
    """Base class for CutMixLayer."""

    def __init__(self):
        super(BaseCutMixLayer, self).__init__()

    @abstractmethod
    def cutmix(self, imgs, gt_label):
        pass


class BatchCutMixLayer(BaseCutMixLayer):
    """CutMix layer for batch CutMix.

    Args:
        alpha (float): Parameter of the Beta distribution. Must be
            positive (> 0).
        num_classes (int): The number of classes.
        cutmix_prob (float): CutMix probability. It should be in the
            range [0, 1].
    """

    def __init__(self, alpha, num_classes, cutmix_prob):
        super(BatchCutMixLayer, self).__init__()

        assert isinstance(alpha, float) and alpha > 0
        assert isinstance(num_classes, int)
        assert isinstance(cutmix_prob, float) and 0.0 <= cutmix_prob <= 1.0

        self.alpha = alpha
        self.num_classes = num_classes
        self.cutmix_prob = cutmix_prob

    def rand_bbox(self, size, lam):
        """Sample a random box covering (1 - lam) of the image area,
        clipped to the image borders."""
        W = size[2]
        H = size[3]
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        # uniform
        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)

        return bbx1, bby1, bbx2, bby2

    def cutmix(self, img, gt_label):
        r = np.random.rand(1)
        if r < self.cutmix_prob:
            lam = np.random.beta(self.alpha, self.alpha)
            batch_size = img.size(0)
            index = torch.randperm(batch_size)

            one_hot_gt_label = F.one_hot(
                gt_label, num_classes=self.num_classes)
            bbx1, bby1, bbx2, bby2 = self.rand_bbox(img.size(), lam)
            # Paste the sampled region from a shuffled copy of the batch.
            # Note that this modifies img in place.
            img[:, :, bbx1:bbx2, bby1:bby2] = \
                img[index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                       (img.size(-1) * img.size(-2)))
            mixed_gt_label = lam * one_hot_gt_label + (
                1 - lam) * one_hot_gt_label[index, :]
            return img, mixed_gt_label
        else:
            one_hot_gt_label = F.one_hot(
                gt_label, num_classes=self.num_classes)
            return img, one_hot_gt_label

    def __call__(self, img, gt_label):
        return self.cutmix(img, gt_label)
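To make the behaviour concrete, here is a small usage sketch of BatchCutMixLayer. It mirrors the unit test added below; note that the layer writes the pasted patches into the input batch in place.

import torch

from mmcls.models.utils import BatchCutMixLayer

imgs = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 10, (4, ))

layer = BatchCutMixLayer(alpha=1.0, num_classes=10, cutmix_prob=1.0)
mixed_imgs, mixed_labels = layer(imgs, labels)  # imgs is modified in place

# Each mixed label is a convex combination of two one-hot vectors, so every
# row still sums to 1; lam is re-derived from the actual (clipped) box area.
assert mixed_labels.shape == (4, 10)
assert torch.allclose(mixed_labels.sum(-1).float(), torch.ones(4))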

View File

@@ -3,8 +3,9 @@ import torch
 from torch.nn.modules import GroupNorm
 from torch.nn.modules.batchnorm import _BatchNorm
 
-from mmcls.models.utils import (BatchMixupLayer, InvertedResidual, SELayer,
-                                channel_shuffle, make_divisible)
+from mmcls.models.utils import (BatchCutMixLayer, BatchMixupLayer,
+                                InvertedResidual, SELayer, channel_shuffle,
+                                make_divisible)
 
 
 def is_norm(modules):
@@ -127,3 +128,16 @@ def test_mixup():
     mixed_img, mixed_label = mixup_layer(img, label)
     assert mixed_img.shape == torch.Size((16, 3, 32, 32))
     assert mixed_label.shape == torch.Size((16, num_classes))
+
+
+def test_cutmix():
+    alpha = 1.0
+    num_classes = 10
+    cutmix_prob = 1.0
+    img = torch.randn(16, 3, 32, 32)
+    label = torch.randint(0, 10, (16, ))
+
+    cutmix_layer = BatchCutMixLayer(alpha, num_classes, cutmix_prob)
+    mixed_img, mixed_label = cutmix_layer(img, label)
+    assert mixed_img.shape == torch.Size((16, 3, 32, 32))
+    assert mixed_label.shape == torch.Size((16, num_classes))

View File

@@ -30,6 +30,34 @@ def test_image_classifier():
     assert losses['loss'].item() > 0
 
 
+def test_image_classifier_with_cutmix():
+    # Test cutmix in ImageClassifier
+    model_cfg = dict(
+        backbone=dict(
+            type='ResNet_CIFAR',
+            depth=50,
+            num_stages=4,
+            out_indices=(3, ),
+            style='pytorch'),
+        neck=dict(type='GlobalAveragePooling'),
+        head=dict(
+            type='MultiLabelLinearClsHead',
+            num_classes=10,
+            in_channels=2048,
+            loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
+                      use_soft=True)),
+        train_cfg=dict(
+            cutmix=dict(alpha=1.0, num_classes=10, cutmix_prob=1.0)))
+
+    img_classifier = ImageClassifier(**model_cfg)
+    img_classifier.init_weights()
+    imgs = torch.randn(16, 3, 32, 32)
+    label = torch.randint(0, 10, (16, ))
+
+    losses = img_classifier.forward_train(imgs, label)
+    assert losses['loss'].item() > 0
+
+
 def test_image_classifier_with_label_smooth_loss():
     # Test mixup in ImageClassifier