2022-02-17 02:17:20 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
2022-01-21 12:30:58 +09:00
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
def one_hot_encoding(gt, num_classes):
    """Change gt_label to one_hot encoding.

    A 1-D label tensor of shape (N,) is converted to a one-hot tensor of
    shape (N, num_classes). A label tensor with any other number of
    dimensions (e.g. shape (N, *) holding binary or multi-label targets,
    which are treated as already encoded) is returned unchanged, without
    encoding.

    Args:
        gt (Tensor): The gt label with shape (N,) or shape (N, *). For 1-D
            input, any integer (or bool) dtype is accepted; it is widened
            to int64 before encoding. Floating-point labels are not cast
            and are rejected by ``F.one_hot``.
        num_classes (int): The number of classes.

    Returns:
        Tensor: One hot gt label (int64) when ``gt`` is 1-D, otherwise
        ``gt`` itself.
    """
    if gt.ndim != 1:
        # Binary classification, e.g. [[0], [1], [1]], or multi-label
        # classification, e.g. [[0, 1, 1], [1, 0, 0], [1, 1, 1]]:
        # the labels are already in encoded form, so pass them through.
        return gt

    # Multi-class classification: one class index per sample.
    # F.one_hot only accepts int64 index tensors, so widen other integer
    # dtypes (int32, uint8, bool, ...) instead of failing. Float/complex
    # labels are deliberately left alone so they are not silently truncated.
    if not gt.is_floating_point() and not gt.is_complex():
        gt = gt.long()
    return F.one_hot(gt, num_classes=num_classes)
|