25 lines
720 B
Python
25 lines
720 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def one_hot_encoding(gt, num_classes):
    """Change gt_label to one_hot encoding.

    A 1-D label tensor of class indices is converted to one-hot form.
    Any other input is assumed to be encoded already and is returned
    unchanged. This includes shape (N, *) with 2 or more dimensions
    (e.g. binary labels ``[[0], [1], [1]]`` or multi-label targets
    ``[[0, 1, 1], [1, 0, 0]]``).

    Args:
        gt (Tensor): The gt label with shape (N,) or shape (N, *).
        num_classes (int): The number of classes.

    Returns:
        Tensor: For 1-D input, the one-hot gt label of shape
        (N, num_classes) with dtype int64. Otherwise ``gt`` itself.

    Raises:
        RuntimeError: Propagated from ``F.one_hot`` for 1-D labels that
            are floating point or lie outside ``[0, num_classes)``.
    """
    if gt.ndim != 1:
        # Binary (e.g. [[0], [1], [1]]) or multi-label
        # (e.g. [[0, 1, 1], [1, 0, 0], [1, 1, 1]]) targets are already
        # encoded; pass them through untouched.
        return gt

    # Multi-class classification: F.one_hot only accepts int64 index
    # tensors. Widen other integer/bool dtypes instead of failing. Float
    # labels are intentionally not cast, so F.one_hot still rejects them
    # rather than silently truncating.
    if not gt.is_floating_point():
        gt = gt.long()
    return F.one_hot(gt, num_classes=num_classes)
|