mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Enhance] Suport Mixup&Cutmix for multi-label task.
This commit is contained in:
parent
b39885d953
commit
d29037e8d1
@ -3,9 +3,9 @@ from abc import ABCMeta, abstractmethod
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from .builder import AUGMENT
|
from .builder import AUGMENT
|
||||||
|
from .utils import one_hot_encoding
|
||||||
|
|
||||||
|
|
||||||
class BaseCutMixLayer(object, metaclass=ABCMeta):
|
class BaseCutMixLayer(object, metaclass=ABCMeta):
|
||||||
@ -123,7 +123,7 @@ class BatchCutMixLayer(BaseCutMixLayer):
|
|||||||
super(BatchCutMixLayer, self).__init__(*args, **kwargs)
|
super(BatchCutMixLayer, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def cutmix(self, img, gt_label):
|
def cutmix(self, img, gt_label):
|
||||||
one_hot_gt_label = F.one_hot(gt_label, num_classes=self.num_classes)
|
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
|
||||||
lam = np.random.beta(self.alpha, self.alpha)
|
lam = np.random.beta(self.alpha, self.alpha)
|
||||||
batch_size = img.size(0)
|
batch_size = img.size(0)
|
||||||
index = torch.randperm(batch_size)
|
index = torch.randperm(batch_size)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from .builder import AUGMENT
|
from .builder import AUGMENT
|
||||||
|
from .utils import one_hot_encoding
|
||||||
|
|
||||||
|
|
||||||
@AUGMENT.register_module(name='Identity')
|
@AUGMENT.register_module(name='Identity')
|
||||||
@ -24,7 +23,7 @@ class Identity(object):
|
|||||||
self.prob = prob
|
self.prob = prob
|
||||||
|
|
||||||
def one_hot(self, gt_label):
|
def one_hot(self, gt_label):
|
||||||
return F.one_hot(gt_label, num_classes=self.num_classes)
|
return one_hot_encoding(gt_label, self.num_classes)
|
||||||
|
|
||||||
def __call__(self, img, gt_label):
|
def __call__(self, img, gt_label):
|
||||||
return img, self.one_hot(gt_label)
|
return img, self.one_hot(gt_label)
|
||||||
|
@ -3,9 +3,9 @@ from abc import ABCMeta, abstractmethod
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from .builder import AUGMENT
|
from .builder import AUGMENT
|
||||||
|
from .utils import one_hot_encoding
|
||||||
|
|
||||||
|
|
||||||
class BaseMixupLayer(object, metaclass=ABCMeta):
|
class BaseMixupLayer(object, metaclass=ABCMeta):
|
||||||
@ -42,7 +42,7 @@ class BatchMixupLayer(BaseMixupLayer):
|
|||||||
super(BatchMixupLayer, self).__init__(*args, **kwargs)
|
super(BatchMixupLayer, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def mixup(self, img, gt_label):
|
def mixup(self, img, gt_label):
|
||||||
one_hot_gt_label = F.one_hot(gt_label, num_classes=self.num_classes)
|
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
|
||||||
lam = np.random.beta(self.alpha, self.alpha)
|
lam = np.random.beta(self.alpha, self.alpha)
|
||||||
batch_size = img.size(0)
|
batch_size = img.size(0)
|
||||||
index = torch.randperm(batch_size)
|
index = torch.randperm(batch_size)
|
||||||
|
23
mmcls/models/utils/augment/utils.py
Normal file
23
mmcls/models/utils/augment/utils.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def one_hot_encoding(gt, num_classes):
|
||||||
|
"""Change gt_label to one_hot encoding.
|
||||||
|
|
||||||
|
If the shape has 2 or more
|
||||||
|
dimensions, return it without encoding.
|
||||||
|
Args:
|
||||||
|
gt (Tensor): The gt label with shape (N,) or shape (N, */).
|
||||||
|
num_classes (int): The number of classes.
|
||||||
|
Return:
|
||||||
|
Tensor: One hot gt label.
|
||||||
|
"""
|
||||||
|
if gt.ndim == 1:
|
||||||
|
# multi-class classification
|
||||||
|
return F.one_hot(gt, num_classes=num_classes)
|
||||||
|
else:
|
||||||
|
# binary classification
|
||||||
|
# example. [[0], [1], [1]]
|
||||||
|
# multi-label classification
|
||||||
|
# example. [[0, 1, 1], [1, 0, 0], [1, 1, 1]]
|
||||||
|
return gt
|
@ -1,8 +1,15 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mmcls.models.utils import Augments
|
from mmcls.models.utils import Augments
|
||||||
|
|
||||||
|
augment_cfgs = [
|
||||||
|
dict(type='BatchCutMix', alpha=1., prob=1.),
|
||||||
|
dict(type='BatchMixup', alpha=1., prob=1.),
|
||||||
|
dict(type='Identity', prob=1.),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_augments():
|
def test_augments():
|
||||||
imgs = torch.randn(4, 3, 32, 32)
|
imgs = torch.randn(4, 3, 32, 32)
|
||||||
@ -50,3 +57,31 @@ def test_augments():
|
|||||||
mixed_imgs, mixed_labels = augs(imgs, labels)
|
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||||
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||||
assert mixed_labels.shape == torch.Size((4, 10))
|
assert mixed_labels.shape == torch.Size((4, 10))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('cfg', augment_cfgs)
|
||||||
|
def test_binary_augment(cfg):
|
||||||
|
|
||||||
|
cfg_ = dict(num_classes=1, **cfg)
|
||||||
|
augs = Augments(cfg_)
|
||||||
|
|
||||||
|
imgs = torch.randn(4, 3, 32, 32)
|
||||||
|
labels = torch.randint(0, 2, (4, 1)).float()
|
||||||
|
|
||||||
|
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||||
|
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||||
|
assert mixed_labels.shape == torch.Size((4, 1))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('cfg', augment_cfgs)
|
||||||
|
def test_multilabel_augment(cfg):
|
||||||
|
|
||||||
|
cfg_ = dict(num_classes=10, **cfg)
|
||||||
|
augs = Augments(cfg_)
|
||||||
|
|
||||||
|
imgs = torch.randn(4, 3, 32, 32)
|
||||||
|
labels = torch.randint(0, 2, (4, 10)).float()
|
||||||
|
|
||||||
|
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||||
|
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||||
|
assert mixed_labels.shape == torch.Size((4, 10))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user