[Enhance] Support Mixup & Cutmix for multi-label tasks.

takuoko 2022-01-21 12:30:58 +09:00 committed by GitHub
parent b39885d953
commit d29037e8d1
5 changed files with 64 additions and 7 deletions

View File

@@ -3,9 +3,9 @@ from abc import ABCMeta, abstractmethod
 import numpy as np
 import torch
-import torch.nn.functional as F
 from .builder import AUGMENT
+from .utils import one_hot_encoding

 class BaseCutMixLayer(object, metaclass=ABCMeta):
@@ -123,7 +123,7 @@ class BatchCutMixLayer(BaseCutMixLayer):
         super(BatchCutMixLayer, self).__init__(*args, **kwargs)

     def cutmix(self, img, gt_label):
-        one_hot_gt_label = F.one_hot(gt_label, num_classes=self.num_classes)
+        one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
         lam = np.random.beta(self.alpha, self.alpha)
         batch_size = img.size(0)
         index = torch.randperm(batch_size)
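The rest of cutmix (the box cut-and-paste and the label mixing) is untouched by this commit. For context, a minimal self-contained sketch of batch CutMix over already one-hot or multi-hot float labels; this mirrors the standard CutMix recipe, not necessarily the exact code in BatchCutMixLayer:

import numpy as np
import torch

def cutmix_batch(img, soft_label, alpha=1.0):
    # sample a mixing ratio and a random pairing of samples within the batch
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(img.size(0))

    # cut a box whose area is roughly (1 - lam) of the image
    h, w = img.shape[-2:]
    cut_h, cut_w = int(h * np.sqrt(1 - lam)), int(w * np.sqrt(1 - lam))
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
    x1, x2 = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)

    # paste the box from the shuffled batch, then mix labels by the actual box area
    mixed_img = img.clone()
    mixed_img[:, :, y1:y2, x1:x2] = img[index, :, y1:y2, x1:x2]
    lam = 1 - (y2 - y1) * (x2 - x1) / (h * w)
    mixed_label = lam * soft_label + (1 - lam) * soft_label[index]
    return mixed_img, mixed_label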

View File

@@ -1,7 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch.nn.functional as F
 from .builder import AUGMENT
+from .utils import one_hot_encoding

 @AUGMENT.register_module(name='Identity')
@@ -24,7 +23,7 @@ class Identity(object):
         self.prob = prob

     def one_hot(self, gt_label):
-        return F.one_hot(gt_label, num_classes=self.num_classes)
+        return one_hot_encoding(gt_label, self.num_classes)

     def __call__(self, img, gt_label):
         return img, self.one_hot(gt_label)

View File

@@ -3,9 +3,9 @@ from abc import ABCMeta, abstractmethod
 import numpy as np
 import torch
-import torch.nn.functional as F
 from .builder import AUGMENT
+from .utils import one_hot_encoding

 class BaseMixupLayer(object, metaclass=ABCMeta):
@@ -42,7 +42,7 @@ class BatchMixupLayer(BaseMixupLayer):
         super(BatchMixupLayer, self).__init__(*args, **kwargs)

     def mixup(self, img, gt_label):
-        one_hot_gt_label = F.one_hot(gt_label, num_classes=self.num_classes)
+        one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
        lam = np.random.beta(self.alpha, self.alpha)
         batch_size = img.size(0)
         index = torch.randperm(batch_size)
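As with CutMix, the blending step itself is unchanged. For reference, a minimal sketch of batch Mixup over float one-hot or multi-hot labels, reusing the lam/index setup shown above (a generic sketch of the standard formula, not a verbatim copy of BatchMixupLayer):

import numpy as np
import torch

def mixup_batch(img, soft_label, alpha=1.0):
    # convex-combine each sample with a randomly paired sample from the same batch
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(img.size(0))
    mixed_img = lam * img + (1 - lam) * img[index, :]
    mixed_label = lam * soft_label + (1 - lam) * soft_label[index, :]
    return mixed_img, mixed_label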

View File

@@ -0,0 +1,23 @@
+import torch.nn.functional as F
+
+
+def one_hot_encoding(gt, num_classes):
+    """Change gt_label to one_hot encoding.
+
+    If the shape has 2 or more
+    dimensions, return it without encoding.
+    Args:
+        gt (Tensor): The gt label with shape (N,) or shape (N, */).
+        num_classes (int): The number of classes.
+    Return:
+        Tensor: One hot gt label.
+    """
+    if gt.ndim == 1:
+        # multi-class classification
+        return F.one_hot(gt, num_classes=num_classes)
+    else:
+        # binary classification
+        # example. [[0], [1], [1]]
+        # multi-label classification
+        # example. [[0, 1, 1], [1, 0, 0], [1, 1, 1]]
+        return gt
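A short standalone illustration of the two code paths; the helper is copied inline so nothing about mmcls import paths (which this diff does not show) is assumed:

import torch
import torch.nn.functional as F

def one_hot_encoding(gt, num_classes):
    # same dispatch as the new utility above
    if gt.ndim == 1:
        return F.one_hot(gt, num_classes=num_classes)
    return gt

# multi-class: 1-D integer labels are expanded to (N, num_classes) one-hot rows
print(one_hot_encoding(torch.tensor([0, 2, 1]), num_classes=3))
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]])

# binary (N, 1) and multi-label (N, C) targets already encode the labels,
# so they pass through untouched and can be mixed directly
print(one_hot_encoding(torch.tensor([[0.], [1.], [1.]]), num_classes=1))
print(one_hot_encoding(torch.tensor([[0., 1., 1.], [1., 0., 0.]]), num_classes=3))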

View File

@@ -1,8 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import pytest
 import torch

 from mmcls.models.utils import Augments

+augment_cfgs = [
+    dict(type='BatchCutMix', alpha=1., prob=1.),
+    dict(type='BatchMixup', alpha=1., prob=1.),
+    dict(type='Identity', prob=1.),
+]

 def test_augments():
     imgs = torch.randn(4, 3, 32, 32)
@@ -50,3 +57,31 @@ def test_augments():
     mixed_imgs, mixed_labels = augs(imgs, labels)
     assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
     assert mixed_labels.shape == torch.Size((4, 10))
+
+
+@pytest.mark.parametrize('cfg', augment_cfgs)
+def test_binary_augment(cfg):
+    cfg_ = dict(num_classes=1, **cfg)
+    augs = Augments(cfg_)
+    imgs = torch.randn(4, 3, 32, 32)
+    labels = torch.randint(0, 2, (4, 1)).float()
+    mixed_imgs, mixed_labels = augs(imgs, labels)
+    assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
+    assert mixed_labels.shape == torch.Size((4, 1))
+
+
+@pytest.mark.parametrize('cfg', augment_cfgs)
+def test_multilabel_augment(cfg):
+    cfg_ = dict(num_classes=10, **cfg)
+    augs = Augments(cfg_)
+    imgs = torch.randn(4, 3, 32, 32)
+    labels = torch.randint(0, 2, (4, 10)).float()
+    mixed_imgs, mixed_labels = augs(imgs, labels)
+    assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
+    assert mixed_labels.shape == torch.Size((4, 10))
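Beyond the unit tests, the same entry point can be driven with multi-hot targets directly. A hedged usage sketch, assuming Augments also accepts a list of augment configs whose probabilities are combined (the tests above only exercise the single-dict form):

import torch
from mmcls.models.utils import Augments

# 50/50 Mixup / CutMix over a 10-class multi-label batch; the new
# one_hot_encoding helper leaves the (N, 10) multi-hot targets as-is
augments = Augments([
    dict(type='BatchMixup', alpha=0.2, num_classes=10, prob=0.5),
    dict(type='BatchCutMix', alpha=1.0, num_classes=10, prob=0.5),
])

imgs = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 2, (4, 10)).float()
mixed_imgs, mixed_labels = augments(imgs, labels)
assert mixed_labels.shape == (4, 10)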