# mmclassification/mmcls/models/utils/augment/identity.py
#
# 30 lines
# 855 B
# Python

# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AUGMENT
from .utils import one_hot_encoding
@AUGMENT.register_module(name='Identity')
class Identity(object):
    """Convert ``gt_label`` to one-hot encoding and return ``img`` unchanged.

    This is the "no-op" batch augmentation. The image is passed through as-is
    (the same object is returned). Only the labels are converted via
    ``one_hot_encoding`` so that they match the label format produced by the
    other batch augmentations.

    Args:
        num_classes (int): The number of classes, forwarded to
            ``one_hot_encoding``.
        prob (float): Probability associated with this augmentation. It must
            be in range [0, 1]. Defaults to 1.0.
            NOTE(review): ``prob`` is only stored here and never read inside
            this class. Presumably a wrapper that picks among the registered
            augments consumes it; confirm against the caller.
    """

    def __init__(self, num_classes, prob=1.0):
        super(Identity, self).__init__()

        # Validation keeps raising AssertionError (as the original did), so
        # callers relying on that exception type are unaffected.
        assert isinstance(num_classes, int), (
            'num_classes must be an int, got {!r}'.format(num_classes))
        assert isinstance(prob, float) and 0.0 <= prob <= 1.0, (
            'prob must be a float in range [0, 1], got {!r}'.format(prob))

        self.num_classes = num_classes
        self.prob = prob

    def one_hot(self, gt_label):
        """Return ``gt_label`` one-hot encoded over ``self.num_classes``.

        NOTE(review): the exact output type and shape are decided by
        ``one_hot_encoding`` in ``.utils``. It is presumably a
        ``(N, num_classes)`` tensor; confirm there.
        """
        return one_hot_encoding(gt_label, self.num_classes)

    def __call__(self, img, gt_label):
        """Return ``(img, one_hot(gt_label))``; ``img`` is not modified."""
        return img, self.one_hot(gt_label)