[Feature] Add cutmix option (#198)

* Add cutmix option
* fix code style
* add some annotations
* add annotation about custom_hooks
* check constraint of alpha > 0
* add test cutmix
* fix code style
* add cutmix to configs/models
* add cutmix to configs/resnet
* flake8
* empty

Branch: pull/220/head
Parent: b7b520881f
Commit: 1cde6f6e65
Annotate the default runtime config with a pointer to custom hooks:

```diff
@@ -8,6 +8,9 @@ log_config = dict(
         # dict(type='TensorboardLoggerHook')
     ])
 # yapf:enable
+# You can register your own hooks like this
+# custom_hooks=[dict(type='EMAHook')]
+
 dist_params = dict(backend='nccl')
 log_level = 'INFO'
 load_from = None
```
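For context, hooks referenced from `custom_hooks` are looked up in MMCV's `HOOKS` registry. A minimal sketch of defining and registering one (the `MyHook` name and behavior are illustrative, not part of this PR):

```python
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class MyHook(Hook):
    """Illustrative hook: log a message when each training epoch ends."""

    def after_train_epoch(self, runner):
        runner.logger.info(f'Epoch {runner.epoch + 1} finished.')


# Enabled from the config via:
# custom_hooks = [dict(type='MyHook')]
```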
New base model config: ResNet-50 on CIFAR-10 with CutMix and a soft-label head (`cutmix_prob` is added here because `BatchCutMixLayer` below requires it, matching the unit test):

```diff
@@ -0,0 +1,16 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='ResNet_CIFAR',
+        depth=50,
+        num_stages=4,
+        out_indices=(3, ),
+        style='pytorch'),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='MultiLabelLinearClsHead',
+        num_classes=10,
+        in_channels=2048,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
+    train_cfg=dict(cutmix=dict(alpha=1.0, num_classes=10, cutmix_prob=1.0)))
```
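For reference, CutMix (Yun et al., 2019) combines two training samples $(x_A, y_A)$ and $(x_B, y_B)$ by pasting a random box from one image into the other and mixing the labels by the area ratio:

```latex
\tilde{x} = \mathbf{M} \odot x_A + (\mathbf{1} - \mathbf{M}) \odot x_B,
\qquad
\tilde{y} = \lambda y_A + (1 - \lambda) y_B,
\qquad
\lambda \sim \operatorname{Beta}(\alpha, \alpha)
```

where $\mathbf{M}$ is a binary mask for the sampled box, and $\lambda$ is re-computed from the actual box area after clipping, as the implementation below does.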
New base model config `resnet50_cutmix.py` (referenced below as `../_base_/models/resnet50_cutmix.py`): ResNet-50 on ImageNet with CutMix:

```diff
@@ -0,0 +1,16 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(3, ),
+        style='pytorch'),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='MultiLabelLinearClsHead',
+        num_classes=1000,
+        in_channels=2048,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)),
+    train_cfg=dict(
+        cutmix=dict(alpha=1.0, num_classes=1000, cutmix_prob=1.0)))
```
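The `use_soft=True` flag is needed because CutMix turns hard labels into soft (N, C) targets, so the head must compute cross entropy against a distribution rather than a class index. Conceptually (a minimal sketch, not mmcls' exact implementation):

```python
import torch.nn.functional as F


def soft_cross_entropy(logits, soft_label):
    """Cross entropy against a soft (N, C) target distribution."""
    log_prob = F.log_softmax(logits, dim=1)
    return -(soft_label * log_prob).sum(dim=1).mean()
```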
New top-level config under `configs/resnet/`, composing the CutMix model with the ImageNet dataset and schedule:

```diff
@@ -0,0 +1,5 @@
+_base_ = [
+    '../_base_/models/resnet50_cutmix.py',
+    '../_base_/datasets/imagenet_bs32.py',
+    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
+]
```
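A sketch of how the `_base_` composition resolves with `mmcv.Config` (the path below is hypothetical; the actual filename of the new config is not shown in this diff):

```python
from mmcv import Config

# Hypothetical path for the new config file added above.
cfg = Config.fromfile('configs/resnet/resnet50_imagenet_cutmix.py')
print(cfg.model.train_cfg)
# -> {'cutmix': {'alpha': 1.0, 'num_classes': 1000, 'cutmix_prob': 1.0}}
```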
In the `ImageClassifier` module, import the new layer:

```diff
@@ -1,7 +1,7 @@
 import torch.nn as nn

 from ..builder import CLASSIFIERS, build_backbone, build_head, build_neck
-from ..utils import BatchMixupLayer
+from ..utils import BatchCutMixLayer, BatchMixupLayer
 from .base import BaseClassifier
```
Build the CutMix layer from `train_cfg`, keeping mixup and CutMix mutually exclusive:

```diff
@@ -24,10 +24,16 @@ class ImageClassifier(BaseClassifier):
         if head is not None:
             self.head = build_head(head)

-        self.mixup = None
+        self.mixup, self.cutmix = None, None
         if train_cfg is not None:
             mixup_cfg = train_cfg.get('mixup', None)
-            self.mixup = BatchMixupLayer(**mixup_cfg)
+            cutmix_cfg = train_cfg.get('cutmix', None)
+            assert mixup_cfg is None or cutmix_cfg is None, \
+                'Mixup and CutMix can not be set simultaneously.'
+            if mixup_cfg is not None:
+                self.mixup = BatchMixupLayer(**mixup_cfg)
+            if cutmix_cfg is not None:
+                self.cutmix = BatchCutMixLayer(**cutmix_cfg)

         self.init_weights(pretrained=pretrained)
```
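A sketch of the failure mode guarded by the assert above: configuring both mixup and CutMix should raise immediately (backbone/head settings mirror the CIFAR config; assumes mmcls is importable):

```python
import pytest
from mmcls.models import ImageClassifier

with pytest.raises(AssertionError):
    ImageClassifier(
        backbone=dict(
            type='ResNet_CIFAR', depth=50, num_stages=4,
            out_indices=(3, ), style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='MultiLabelLinearClsHead', num_classes=10,
            in_channels=2048,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
                      use_soft=True)),
        # Setting both triggers the mutual-exclusion assert.
        train_cfg=dict(
            mixup=dict(alpha=1.0, num_classes=10),
            cutmix=dict(alpha=1.0, num_classes=10, cutmix_prob=1.0)))
```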
Apply CutMix to the batch in `forward_train`:

```diff
@@ -56,18 +62,19 @@ class ImageClassifier(BaseClassifier):
         Args:
             img (Tensor): of shape (N, C, H, W) encoding input images.
                 Typically these should be mean centered and std scaled.

             gt_label (Tensor): It should be of shape (N, 1) encoding the
                 ground-truth label of input images for a single-label task,
                 and of shape (N, C) encoding the ground-truth label of
                 input images for a multi-label task.

         Returns:
             dict[str, Tensor]: a dictionary of loss components
         """
         if self.mixup is not None:
             img, gt_label = self.mixup(img, gt_label)

+        if self.cutmix is not None:
+            img, gt_label = self.cutmix(img, gt_label)
+
         x = self.extract_feat(img)

         losses = dict()
```
Register `BatchCutMixLayer` in `mmcls.models.utils`:

```diff
@@ -1,4 +1,5 @@
 from .channel_shuffle import channel_shuffle
+from .cutmix import BatchCutMixLayer
 from .inverted_residual import InvertedResidual
 from .make_divisible import make_divisible
 from .mixup import BatchMixupLayer
@@ -6,5 +7,5 @@ from .se_layer import SELayer

 __all__ = [
     'channel_shuffle', 'make_divisible', 'InvertedResidual', 'BatchMixupLayer',
-    'SELayer'
+    'BatchCutMixLayer', 'SELayer'
 ]
```
|
@ -0,0 +1,80 @@
|
|||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class BaseCutMixLayer(object, metaclass=ABCMeta):
|
||||
"""Base class for CutMixLayer"""
|
||||
|
||||
def __init__(self):
|
||||
super(BaseCutMixLayer, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def cutmix(self, imgs, gt_label):
|
||||
pass
|
||||
|
||||
|
||||
class BatchCutMixLayer(BaseCutMixLayer):
|
||||
"""CutMix layer for batch CutMix.
|
||||
|
||||
Args:
|
||||
alpha (float): Parameters for Beta distribution. Positive(>0).
|
||||
num_classes (int): The number of classes.
|
||||
cutmix_prob (float): CutMix probability. It should be in range [0, 1]
|
||||
"""
|
||||
|
||||
def __init__(self, alpha, num_classes, cutmix_prob):
|
||||
super(BatchCutMixLayer, self).__init__()
|
||||
|
||||
assert isinstance(alpha, float) and alpha > 0
|
||||
assert isinstance(num_classes, int)
|
||||
assert isinstance(cutmix_prob, float) and 0.0 <= cutmix_prob <= 1.0
|
||||
|
||||
self.alpha = alpha
|
||||
self.num_classes = num_classes
|
||||
self.cutmix_prob = cutmix_prob
|
||||
|
||||
def rand_bbox(self, size, lam):
|
||||
W = size[2]
|
||||
H = size[3]
|
||||
cut_rat = np.sqrt(1. - lam)
|
||||
cut_w = np.int(W * cut_rat)
|
||||
cut_h = np.int(H * cut_rat)
|
||||
|
||||
# uniform
|
||||
cx = np.random.randint(W)
|
||||
cy = np.random.randint(H)
|
||||
|
||||
bbx1 = np.clip(cx - cut_w // 2, 0, W)
|
||||
bby1 = np.clip(cy - cut_h // 2, 0, H)
|
||||
bbx2 = np.clip(cx + cut_w // 2, 0, W)
|
||||
bby2 = np.clip(cy + cut_h // 2, 0, H)
|
||||
|
||||
return bbx1, bby1, bbx2, bby2
|
||||
|
||||
def cutmix(self, img, gt_label):
|
||||
r = np.random.rand(1)
|
||||
if r < self.cutmix_prob:
|
||||
lam = np.random.beta(self.alpha, self.alpha)
|
||||
batch_size = img.size(0)
|
||||
index = torch.randperm(batch_size)
|
||||
one_hot_gt_label = F.one_hot(
|
||||
gt_label, num_classes=self.num_classes)
|
||||
bbx1, bby1, bbx2, bby2 = self.rand_bbox(img.size(), lam)
|
||||
img[:, :, bbx1:bbx2, bby1:bby2] = \
|
||||
img[index, :, bbx1:bbx2, bby1:bby2]
|
||||
# adjust lambda to exactly match pixel ratio
|
||||
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
|
||||
(img.size(-1) * img.size(-2)))
|
||||
mixed_gt_label = lam * one_hot_gt_label + (
|
||||
1 - lam) * one_hot_gt_label[index, :]
|
||||
return img, mixed_gt_label
|
||||
else:
|
||||
one_hot_gt_label = F.one_hot(
|
||||
gt_label, num_classes=self.num_classes)
|
||||
return img, one_hot_gt_label
|
||||
|
||||
def __call__(self, img, gt_label):
|
||||
return self.cutmix(img, gt_label)
|
|
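A quick sketch of the layer in isolation (with `cutmix_prob=1.0` the mixed branch is always taken): since the soft label is `lam * one_hot + (1 - lam) * one_hot[index]`, each row keeps unit mass.

```python
import torch
from mmcls.models.utils import BatchCutMixLayer

layer = BatchCutMixLayer(alpha=1.0, num_classes=10, cutmix_prob=1.0)
imgs = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 10, (4, ))
mixed_imgs, mixed_labels = layer(imgs, labels)

assert mixed_imgs.shape == (4, 3, 32, 32)
assert mixed_labels.shape == (4, 10)
# Convex combination of two one-hot vectors sums to 1 per sample.
assert torch.allclose(mixed_labels.sum(dim=1).float(), torch.ones(4))
```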
Update the utils test imports:

```diff
@@ -3,8 +3,9 @@ import torch
 from torch.nn.modules import GroupNorm
 from torch.nn.modules.batchnorm import _BatchNorm

-from mmcls.models.utils import (BatchMixupLayer, InvertedResidual, SELayer,
-                                channel_shuffle, make_divisible)
+from mmcls.models.utils import (BatchCutMixLayer, BatchMixupLayer,
+                                InvertedResidual, SELayer, channel_shuffle,
+                                make_divisible)


 def is_norm(modules):
```
Add a unit test for `BatchCutMixLayer` (the local variable is named `cutmix_layer` here rather than `mixup_layer`, since it holds a CutMix layer):

```diff
@@ -127,3 +128,16 @@ def test_mixup():
     mixed_img, mixed_label = mixup_layer(img, label)
     assert mixed_img.shape == torch.Size((16, 3, 32, 32))
     assert mixed_label.shape == torch.Size((16, num_classes))
+
+
+def test_cutmix():
+
+    alpha = 1.0
+    num_classes = 10
+    cutmix_prob = 1.0
+    img = torch.randn(16, 3, 32, 32)
+    label = torch.randint(0, 10, (16, ))
+    cutmix_layer = BatchCutMixLayer(alpha, num_classes, cutmix_prob)
+    mixed_img, mixed_label = cutmix_layer(img, label)
+    assert mixed_img.shape == torch.Size((16, 3, 32, 32))
+    assert mixed_label.shape == torch.Size((16, num_classes))
```
Add a classifier-level test exercising CutMix end to end:

```diff
@@ -30,6 +30,34 @@ def test_image_classifier():
     assert losses['loss'].item() > 0


+def test_image_classifier_with_cutmix():
+
+    # Test cutmix in ImageClassifier
+    model_cfg = dict(
+        backbone=dict(
+            type='ResNet_CIFAR',
+            depth=50,
+            num_stages=4,
+            out_indices=(3, ),
+            style='pytorch'),
+        neck=dict(type='GlobalAveragePooling'),
+        head=dict(
+            type='MultiLabelLinearClsHead',
+            num_classes=10,
+            in_channels=2048,
+            loss=dict(type='CrossEntropyLoss', loss_weight=1.0,
+                      use_soft=True)),
+        train_cfg=dict(
+            cutmix=dict(alpha=1.0, num_classes=10, cutmix_prob=1.0)))
+    img_classifier = ImageClassifier(**model_cfg)
+    img_classifier.init_weights()
+    imgs = torch.randn(16, 3, 32, 32)
+    label = torch.randint(0, 10, (16, ))
+
+    losses = img_classifier.forward_train(imgs, label)
+    assert losses['loss'].item() > 0
+
+
 def test_image_classifier_with_label_smooth_loss():

     # Test mixup in ImageClassifier
```