From 1cde6f6e6517094efec05776967ff9e4af0c0683 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Wed, 14 Apr 2021 21:27:42 +0800 Subject: [PATCH] [Feature] Add cutmix option (#198) * Add cutmix option * fix code style * add some annotations * add annotation about custom_hooks * check constraint of alpha > 0 * add test cutmix * fix code style * add cutmix to configs/models * add cutmix to configs/resnet * flake8 * empty --- configs/_base_/default_runtime.py | 3 + .../_base_/models/resnet50_cifar_cutmix.py | 16 ++++ configs/_base_/models/resnet50_cutmix.py | 16 ++++ .../resnet/resnet50_b32x8_cutmix_imagenet.py | 5 ++ mmcls/models/classifiers/image.py | 17 ++-- mmcls/models/utils/__init__.py | 3 +- mmcls/models/utils/cutmix.py | 80 +++++++++++++++++++ tests/test_backbones/test_utils.py | 18 ++++- tests/test_classifiers.py | 28 +++++++ 9 files changed, 178 insertions(+), 8 deletions(-) create mode 100644 configs/_base_/models/resnet50_cifar_cutmix.py create mode 100644 configs/_base_/models/resnet50_cutmix.py create mode 100644 configs/resnet/resnet50_b32x8_cutmix_imagenet.py create mode 100644 mmcls/models/utils/cutmix.py diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index ee4228c5..90240e49 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -8,6 +8,9 @@ log_config = dict( # dict(type='TensorboardLoggerHook') ]) # yapf:enable +# You can register your own hooks like this +# custom_hooks=[dict(type='EMAHook')] + dist_params = dict(backend='nccl') log_level = 'INFO' load_from = None diff --git a/configs/_base_/models/resnet50_cifar_cutmix.py b/configs/_base_/models/resnet50_cifar_cutmix.py new file mode 100644 index 00000000..86714d6b --- /dev/null +++ b/configs/_base_/models/resnet50_cifar_cutmix.py @@ -0,0 +1,16 @@ +# model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='ResNet_CIFAR', + depth=50, + num_stages=4, + out_indices=(3, ), + style='pytorch'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='MultiLabelLinearClsHead', + num_classes=10, + in_channels=2048, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)), + train_cfg=dict(cutmix=dict(alpha=1.0, num_classes=10))) diff --git a/configs/_base_/models/resnet50_cutmix.py b/configs/_base_/models/resnet50_cutmix.py new file mode 100644 index 00000000..71de7f0b --- /dev/null +++ b/configs/_base_/models/resnet50_cutmix.py @@ -0,0 +1,16 @@ +# model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(3, ), + style='pytorch'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='MultiLabelLinearClsHead', + num_classes=1000, + in_channels=2048, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0, use_soft=True)), + train_cfg=dict(cutmix=dict(alpha=1.0, num_classes=1000))) diff --git a/configs/resnet/resnet50_b32x8_cutmix_imagenet.py b/configs/resnet/resnet50_b32x8_cutmix_imagenet.py new file mode 100644 index 00000000..2f8d0ca9 --- /dev/null +++ b/configs/resnet/resnet50_b32x8_cutmix_imagenet.py @@ -0,0 +1,5 @@ +_base_ = [ + '../_base_/models/resnet50_cutmix.py', + '../_base_/datasets/imagenet_bs32.py', + '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py' +] diff --git a/mmcls/models/classifiers/image.py b/mmcls/models/classifiers/image.py index e70d88d0..2cb282b5 100644 --- a/mmcls/models/classifiers/image.py +++ b/mmcls/models/classifiers/image.py @@ -1,7 +1,7 @@ import torch.nn as nn from ..builder import CLASSIFIERS, build_backbone, build_head, build_neck -from ..utils import BatchMixupLayer +from ..utils import BatchCutMixLayer, BatchMixupLayer from .base import BaseClassifier @@ -24,10 +24,16 @@ class ImageClassifier(BaseClassifier): if head is not None: self.head = build_head(head) - self.mixup = None + self.mixup, self.cutmix = None, None if train_cfg is not None: mixup_cfg = train_cfg.get('mixup', None) - self.mixup = BatchMixupLayer(**mixup_cfg) + cutmix_cfg = train_cfg.get('cutmix', None) + assert mixup_cfg is None or cutmix_cfg is None, \ + 'Mixup and CutMix can not be set simultaneously.' + if mixup_cfg is not None: + self.mixup = BatchMixupLayer(**mixup_cfg) + if cutmix_cfg is not None: + self.cutmix = BatchCutMixLayer(**cutmix_cfg) self.init_weights(pretrained=pretrained) @@ -56,18 +62,19 @@ class ImageClassifier(BaseClassifier): Args: img (Tensor): of shape (N, C, H, W) encoding input images. Typically these should be mean centered and std scaled. - gt_label (Tensor): It should be of shape (N, 1) encoding the ground-truth label of input images for single label task. It shoulf be of shape (N, C) encoding the ground-truth label of input images for multi-labels task. - Returns: dict[str, Tensor]: a dictionary of loss components """ if self.mixup is not None: img, gt_label = self.mixup(img, gt_label) + if self.cutmix is not None: + img, gt_label = self.cutmix(img, gt_label) + x = self.extract_feat(img) losses = dict() diff --git a/mmcls/models/utils/__init__.py b/mmcls/models/utils/__init__.py index 6e981769..c1c1a750 100644 --- a/mmcls/models/utils/__init__.py +++ b/mmcls/models/utils/__init__.py @@ -1,4 +1,5 @@ from .channel_shuffle import channel_shuffle +from .cutmix import BatchCutMixLayer from .inverted_residual import InvertedResidual from .make_divisible import make_divisible from .mixup import BatchMixupLayer @@ -6,5 +7,5 @@ from .se_layer import SELayer __all__ = [ 'channel_shuffle', 'make_divisible', 'InvertedResidual', 'BatchMixupLayer', - 'SELayer' + 'BatchCutMixLayer', 'SELayer' ] diff --git a/mmcls/models/utils/cutmix.py b/mmcls/models/utils/cutmix.py new file mode 100644 index 00000000..e0294c72 --- /dev/null +++ b/mmcls/models/utils/cutmix.py @@ -0,0 +1,80 @@ +from abc import ABCMeta, abstractmethod + +import numpy as np +import torch +import torch.nn.functional as F + + +class BaseCutMixLayer(object, metaclass=ABCMeta): + """Base class for CutMixLayer""" + + def __init__(self): + super(BaseCutMixLayer, self).__init__() + + @abstractmethod + def cutmix(self, imgs, gt_label): + pass + + +class BatchCutMixLayer(BaseCutMixLayer): + """CutMix layer for batch CutMix. + + Args: + alpha (float): Parameters for Beta distribution. Positive(>0). + num_classes (int): The number of classes. + cutmix_prob (float): CutMix probability. It should be in range [0, 1] + """ + + def __init__(self, alpha, num_classes, cutmix_prob): + super(BatchCutMixLayer, self).__init__() + + assert isinstance(alpha, float) and alpha > 0 + assert isinstance(num_classes, int) + assert isinstance(cutmix_prob, float) and 0.0 <= cutmix_prob <= 1.0 + + self.alpha = alpha + self.num_classes = num_classes + self.cutmix_prob = cutmix_prob + + def rand_bbox(self, size, lam): + W = size[2] + H = size[3] + cut_rat = np.sqrt(1. - lam) + cut_w = np.int(W * cut_rat) + cut_h = np.int(H * cut_rat) + + # uniform + cx = np.random.randint(W) + cy = np.random.randint(H) + + bbx1 = np.clip(cx - cut_w // 2, 0, W) + bby1 = np.clip(cy - cut_h // 2, 0, H) + bbx2 = np.clip(cx + cut_w // 2, 0, W) + bby2 = np.clip(cy + cut_h // 2, 0, H) + + return bbx1, bby1, bbx2, bby2 + + def cutmix(self, img, gt_label): + r = np.random.rand(1) + if r < self.cutmix_prob: + lam = np.random.beta(self.alpha, self.alpha) + batch_size = img.size(0) + index = torch.randperm(batch_size) + one_hot_gt_label = F.one_hot( + gt_label, num_classes=self.num_classes) + bbx1, bby1, bbx2, bby2 = self.rand_bbox(img.size(), lam) + img[:, :, bbx1:bbx2, bby1:bby2] = \ + img[index, :, bbx1:bbx2, bby1:bby2] + # adjust lambda to exactly match pixel ratio + lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / + (img.size(-1) * img.size(-2))) + mixed_gt_label = lam * one_hot_gt_label + ( + 1 - lam) * one_hot_gt_label[index, :] + return img, mixed_gt_label + else: + one_hot_gt_label = F.one_hot( + gt_label, num_classes=self.num_classes) + return img, one_hot_gt_label + + def __call__(self, img, gt_label): + return self.cutmix(img, gt_label) diff --git a/tests/test_backbones/test_utils.py b/tests/test_backbones/test_utils.py index 05e59785..1eb578b3 100644 --- a/tests/test_backbones/test_utils.py +++ b/tests/test_backbones/test_utils.py @@ -3,8 +3,9 @@ import torch from torch.nn.modules import GroupNorm from torch.nn.modules.batchnorm import _BatchNorm -from mmcls.models.utils import (BatchMixupLayer, InvertedResidual, SELayer, - channel_shuffle, make_divisible) +from mmcls.models.utils import (BatchCutMixLayer, BatchMixupLayer, + InvertedResidual, SELayer, channel_shuffle, + make_divisible) def is_norm(modules): @@ -127,3 +128,16 @@ def test_mixup(): mixed_img, mixed_label = mixup_layer(img, label) assert mixed_img.shape == torch.Size((16, 3, 32, 32)) assert mixed_label.shape == torch.Size((16, num_classes)) + + +def test_cutmix(): + + alpha = 1.0 + num_classes = 10 + cutmix_prob = 1.0 + img = torch.randn(16, 3, 32, 32) + label = torch.randint(0, 10, (16, )) + mixup_layer = BatchCutMixLayer(alpha, num_classes, cutmix_prob) + mixed_img, mixed_label = mixup_layer(img, label) + assert mixed_img.shape == torch.Size((16, 3, 32, 32)) + assert mixed_label.shape == torch.Size((16, num_classes)) diff --git a/tests/test_classifiers.py b/tests/test_classifiers.py index feab857a..639d6c61 100644 --- a/tests/test_classifiers.py +++ b/tests/test_classifiers.py @@ -30,6 +30,34 @@ def test_image_classifier(): assert losses['loss'].item() > 0 +def test_image_classifier_with_cutmix(): + + # Test cutmix in ImageClassifier + model_cfg = dict( + backbone=dict( + type='ResNet_CIFAR', + depth=50, + num_stages=4, + out_indices=(3, ), + style='pytorch'), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='MultiLabelLinearClsHead', + num_classes=10, + in_channels=2048, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0, + use_soft=True)), + train_cfg=dict( + cutmix=dict(alpha=1.0, num_classes=10, cutmix_prob=1.0))) + img_classifier = ImageClassifier(**model_cfg) + img_classifier.init_weights() + imgs = torch.randn(16, 3, 32, 32) + label = torch.randint(0, 10, (16, )) + + losses = img_classifier.forward_train(imgs, label) + assert losses['loss'].item() > 0 + + def test_image_classifier_with_label_smooth_loss(): # Test mixup in ImageClassifier