diff --git a/mmcls/models/utils/augment/cutmix.py b/mmcls/models/utils/augment/cutmix.py index 215e878d0..45d758dfe 100644 --- a/mmcls/models/utils/augment/cutmix.py +++ b/mmcls/models/utils/augment/cutmix.py @@ -3,9 +3,9 @@ from abc import ABCMeta, abstractmethod import numpy as np import torch -import torch.nn.functional as F from .builder import AUGMENT +from .utils import one_hot_encoding class BaseCutMixLayer(object, metaclass=ABCMeta): @@ -123,7 +123,7 @@ class BatchCutMixLayer(BaseCutMixLayer): super(BatchCutMixLayer, self).__init__(*args, **kwargs) def cutmix(self, img, gt_label): - one_hot_gt_label = F.one_hot(gt_label, num_classes=self.num_classes) + one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes) lam = np.random.beta(self.alpha, self.alpha) batch_size = img.size(0) index = torch.randperm(batch_size) diff --git a/mmcls/models/utils/augment/identity.py b/mmcls/models/utils/augment/identity.py index e676fc422..ae3a3df52 100644 --- a/mmcls/models/utils/augment/identity.py +++ b/mmcls/models/utils/augment/identity.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import torch.nn.functional as F - from .builder import AUGMENT +from .utils import one_hot_encoding @AUGMENT.register_module(name='Identity') @@ -24,7 +23,7 @@ class Identity(object): self.prob = prob def one_hot(self, gt_label): - return F.one_hot(gt_label, num_classes=self.num_classes) + return one_hot_encoding(gt_label, self.num_classes) def __call__(self, img, gt_label): return img, self.one_hot(gt_label) diff --git a/mmcls/models/utils/augment/mixup.py b/mmcls/models/utils/augment/mixup.py index 2d6cd2b53..17c20704c 100644 --- a/mmcls/models/utils/augment/mixup.py +++ b/mmcls/models/utils/augment/mixup.py @@ -3,9 +3,9 @@ from abc import ABCMeta, abstractmethod import numpy as np import torch -import torch.nn.functional as F from .builder import AUGMENT +from .utils import one_hot_encoding class BaseMixupLayer(object, metaclass=ABCMeta): @@ -42,7 +42,7 @@ class BatchMixupLayer(BaseMixupLayer): super(BatchMixupLayer, self).__init__(*args, **kwargs) def mixup(self, img, gt_label): - one_hot_gt_label = F.one_hot(gt_label, num_classes=self.num_classes) + one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes) lam = np.random.beta(self.alpha, self.alpha) batch_size = img.size(0) index = torch.randperm(batch_size) diff --git a/mmcls/models/utils/augment/utils.py b/mmcls/models/utils/augment/utils.py new file mode 100644 index 000000000..0544af3ec --- /dev/null +++ b/mmcls/models/utils/augment/utils.py @@ -0,0 +1,23 @@ +import torch.nn.functional as F + + +def one_hot_encoding(gt, num_classes): + """Convert gt_label to one-hot encoding. + + Only 1-D labels are encoded; a label with any other number of + dimensions is returned unchanged. + Args: + gt (Tensor): The gt label with shape (N,) or shape (N, *). + num_classes (int): The number of classes (used only for 1-D input). + Returns: + Tensor: The one-hot gt label, or ``gt`` itself if it is not 1-D. + """ + if gt.ndim == 1: + # multi-class classification + return F.one_hot(gt, num_classes=num_classes) + else: + # binary classification + # example. 
[[0], [1], [1]] + # multi-label classification + # example. [[0, 1, 1], [1, 0, 0], [1, 1, 1]] + return gt diff --git a/tests/test_models/test_utils/test_augment.py b/tests/test_models/test_utils/test_augment.py index dd7e1e0bb..a037ad5c3 100644 --- a/tests/test_models/test_utils/test_augment.py +++ b/tests/test_models/test_utils/test_augment.py @@ -1,8 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +import pytest import torch from mmcls.models.utils import Augments +augment_cfgs = [ + dict(type='BatchCutMix', alpha=1., prob=1.), + dict(type='BatchMixup', alpha=1., prob=1.), + dict(type='Identity', prob=1.), +] + def test_augments(): imgs = torch.randn(4, 3, 32, 32) @@ -50,3 +57,31 @@ def test_augments(): mixed_imgs, mixed_labels = augs(imgs, labels) assert mixed_imgs.shape == torch.Size((4, 3, 32, 32)) assert mixed_labels.shape == torch.Size((4, 10)) + + +@pytest.mark.parametrize('cfg', augment_cfgs) +def test_binary_augment(cfg): + + cfg_ = dict(num_classes=1, **cfg) + augs = Augments(cfg_) + + imgs = torch.randn(4, 3, 32, 32) + labels = torch.randint(0, 2, (4, 1)).float() + + mixed_imgs, mixed_labels = augs(imgs, labels) + assert mixed_imgs.shape == torch.Size((4, 3, 32, 32)) + assert mixed_labels.shape == torch.Size((4, 1)) + + +@pytest.mark.parametrize('cfg', augment_cfgs) +def test_multilabel_augment(cfg): + + cfg_ = dict(num_classes=10, **cfg) + augs = Augments(cfg_) + + imgs = torch.randn(4, 3, 32, 32) + labels = torch.randint(0, 2, (4, 10)).float() + + mixed_imgs, mixed_labels = augs(imgs, labels) + assert mixed_imgs.shape == torch.Size((4, 3, 32, 32)) + assert mixed_labels.shape == torch.Size((4, 10))