mmclassification/mmcls/models/utils/augment/utils.py

25 lines
720 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
def one_hot_encoding(gt, num_classes):
"""Change gt_label to one_hot encoding.
If the shape has 2 or more
dimensions, return it without encoding.
Args:
gt (Tensor): The gt label with shape (N,) or shape (N, */).
num_classes (int): The number of classes.
Return:
Tensor: One hot gt label.
"""
if gt.ndim == 1:
# multi-class classification
return F.one_hot(gt, num_classes=num_classes)
else:
# binary classification
# example. [[0], [1], [1]]
# multi-label classification
# example. [[0, 1, 1], [1, 0, 0], [1, 1, 1]]
return gt