mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Enhance] Suport Mixup&Cutmix for multi-label task.
This commit is contained in:
parent
b39885d953
commit
d29037e8d1
@ -3,9 +3,9 @@ from abc import ABCMeta, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .builder import AUGMENT
|
||||
from .utils import one_hot_encoding
|
||||
|
||||
|
||||
class BaseCutMixLayer(object, metaclass=ABCMeta):
|
||||
@ -123,7 +123,7 @@ class BatchCutMixLayer(BaseCutMixLayer):
|
||||
super(BatchCutMixLayer, self).__init__(*args, **kwargs)
|
||||
|
||||
def cutmix(self, img, gt_label):
|
||||
one_hot_gt_label = F.one_hot(gt_label, num_classes=self.num_classes)
|
||||
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
|
||||
lam = np.random.beta(self.alpha, self.alpha)
|
||||
batch_size = img.size(0)
|
||||
index = torch.randperm(batch_size)
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .builder import AUGMENT
|
||||
from .utils import one_hot_encoding
|
||||
|
||||
|
||||
@AUGMENT.register_module(name='Identity')
|
||||
@ -24,7 +23,7 @@ class Identity(object):
|
||||
self.prob = prob
|
||||
|
||||
def one_hot(self, gt_label):
|
||||
return F.one_hot(gt_label, num_classes=self.num_classes)
|
||||
return one_hot_encoding(gt_label, self.num_classes)
|
||||
|
||||
def __call__(self, img, gt_label):
|
||||
return img, self.one_hot(gt_label)
|
||||
|
@ -3,9 +3,9 @@ from abc import ABCMeta, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .builder import AUGMENT
|
||||
from .utils import one_hot_encoding
|
||||
|
||||
|
||||
class BaseMixupLayer(object, metaclass=ABCMeta):
|
||||
@ -42,7 +42,7 @@ class BatchMixupLayer(BaseMixupLayer):
|
||||
super(BatchMixupLayer, self).__init__(*args, **kwargs)
|
||||
|
||||
def mixup(self, img, gt_label):
|
||||
one_hot_gt_label = F.one_hot(gt_label, num_classes=self.num_classes)
|
||||
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
|
||||
lam = np.random.beta(self.alpha, self.alpha)
|
||||
batch_size = img.size(0)
|
||||
index = torch.randperm(batch_size)
|
||||
|
23
mmcls/models/utils/augment/utils.py
Normal file
23
mmcls/models/utils/augment/utils.py
Normal file
@ -0,0 +1,23 @@
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def one_hot_encoding(gt, num_classes):
|
||||
"""Change gt_label to one_hot encoding.
|
||||
|
||||
If the shape has 2 or more
|
||||
dimensions, return it without encoding.
|
||||
Args:
|
||||
gt (Tensor): The gt label with shape (N,) or shape (N, */).
|
||||
num_classes (int): The number of classes.
|
||||
Return:
|
||||
Tensor: One hot gt label.
|
||||
"""
|
||||
if gt.ndim == 1:
|
||||
# multi-class classification
|
||||
return F.one_hot(gt, num_classes=num_classes)
|
||||
else:
|
||||
# binary classification
|
||||
# example. [[0], [1], [1]]
|
||||
# multi-label classification
|
||||
# example. [[0, 1, 1], [1, 0, 0], [1, 1, 1]]
|
||||
return gt
|
@ -1,8 +1,15 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.models.utils import Augments
|
||||
|
||||
augment_cfgs = [
|
||||
dict(type='BatchCutMix', alpha=1., prob=1.),
|
||||
dict(type='BatchMixup', alpha=1., prob=1.),
|
||||
dict(type='Identity', prob=1.),
|
||||
]
|
||||
|
||||
|
||||
def test_augments():
|
||||
imgs = torch.randn(4, 3, 32, 32)
|
||||
@ -50,3 +57,31 @@ def test_augments():
|
||||
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||
assert mixed_labels.shape == torch.Size((4, 10))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg', augment_cfgs)
|
||||
def test_binary_augment(cfg):
|
||||
|
||||
cfg_ = dict(num_classes=1, **cfg)
|
||||
augs = Augments(cfg_)
|
||||
|
||||
imgs = torch.randn(4, 3, 32, 32)
|
||||
labels = torch.randint(0, 2, (4, 1)).float()
|
||||
|
||||
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||
assert mixed_labels.shape == torch.Size((4, 1))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg', augment_cfgs)
|
||||
def test_multilabel_augment(cfg):
|
||||
|
||||
cfg_ = dict(num_classes=10, **cfg)
|
||||
augs = Augments(cfg_)
|
||||
|
||||
imgs = torch.randn(4, 3, 32, 32)
|
||||
labels = torch.randint(0, 2, (4, 10)).float()
|
||||
|
||||
mixed_imgs, mixed_labels = augs(imgs, labels)
|
||||
assert mixed_imgs.shape == torch.Size((4, 3, 32, 32))
|
||||
assert mixed_labels.shape == torch.Size((4, 10))
|
||||
|
Loading…
x
Reference in New Issue
Block a user