mirror of
https://github.com/open-mmlab/mmpretrain.git
synced 2025-06-03 14:59:18 +08:00
24 lines
672 B
Python
24 lines
672 B
Python
|
import torch.nn.functional as F
|
||
|
|
||
|
|
||
|
def one_hot_encoding(gt, num_classes):
    """Convert a 1-D label tensor to one-hot encoding.

    Only tensors with exactly one dimension are encoded. A tensor of any
    other rank is returned unchanged (the very same object, not a copy).

    Args:
        gt (Tensor): The gt label with shape (N,) or shape (N, *).
        num_classes (int): The number of classes.

    Returns:
        Tensor: One-hot gt label produced by ``F.one_hot`` when ``gt`` is
        1-D; otherwise ``gt`` itself.
    """
    if gt.ndim != 1:
        # Not a flat vector of class indices, so leave it untouched.
        # binary classification
        #   example. [[0], [1], [1]]
        # multi-label classification
        #   example. [[0, 1, 1], [1, 0, 0], [1, 1, 1]]
        return gt

    # multi-class classification: one class index per sample
    return F.one_hot(gt, num_classes=num_classes)
|