[Enhance] Use PyTorch official `one_hot` to implement `convert_to_one_hot`. (#696)
* some change to mmcls/models/losses/utils.py:convert_to_one_hot() * fixed problem: line too long * fixed wrong output shape * fixed lint PEP8 E128 * fix lint * fix lint * add unit tests Co-authored-by: Ezra-Yu <1105212286@qq.com>pull/717/head
parent
5f7322c211
commit
1214df083d
|
@ -114,8 +114,6 @@ def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor:
|
|||
"""
|
||||
assert (torch.max(targets).item() <
|
||||
classes), 'Class Index must be less than number of classes'
|
||||
one_hot_targets = torch.zeros((targets.shape[0], classes),
|
||||
dtype=torch.long,
|
||||
device=targets.device)
|
||||
one_hot_targets.scatter_(1, targets.long(), 1)
|
||||
one_hot_targets = F.one_hot(
|
||||
targets.long().squeeze(-1), num_classes=classes)
|
||||
return one_hot_targets
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.models.losses.utils import convert_to_one_hot
|
||||
|
||||
|
||||
def ori_convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor:
|
||||
assert (torch.max(targets).item() <
|
||||
classes), 'Class Index must be less than number of classes'
|
||||
one_hot_targets = torch.zeros((targets.shape[0], classes),
|
||||
dtype=torch.long,
|
||||
device=targets.device)
|
||||
one_hot_targets.scatter_(1, targets.long(), 1)
|
||||
return one_hot_targets
|
||||
|
||||
|
||||
def test_convert_to_one_hot():
|
||||
# label should smaller than classes
|
||||
targets = torch.tensor([1, 2, 3, 8, 5])
|
||||
classes = 5
|
||||
with pytest.raises(AssertionError):
|
||||
_ = convert_to_one_hot(targets, classes)
|
||||
|
||||
# test with original impl
|
||||
classes = 10
|
||||
targets = torch.randint(high=classes, size=(10, 1))
|
||||
ori_one_hot_targets = torch.zeros((targets.shape[0], classes),
|
||||
dtype=torch.long,
|
||||
device=targets.device)
|
||||
ori_one_hot_targets.scatter_(1, targets.long(), 1)
|
||||
one_hot_targets = convert_to_one_hot(targets, classes)
|
||||
assert torch.equal(ori_one_hot_targets, one_hot_targets)
|
||||
|
||||
|
||||
# test cuda version
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_convert_to_one_hot_cuda():
|
||||
# test with original impl
|
||||
classes = 10
|
||||
targets = torch.randint(high=classes, size=(10, 1)).cuda()
|
||||
ori_one_hot_targets = torch.zeros((targets.shape[0], classes),
|
||||
dtype=torch.long,
|
||||
device=targets.device)
|
||||
ori_one_hot_targets.scatter_(1, targets.long(), 1)
|
||||
one_hot_targets = convert_to_one_hot(targets, classes)
|
||||
assert torch.equal(ori_one_hot_targets, one_hot_targets)
|
||||
assert ori_one_hot_targets.device == one_hot_targets.device
|
Loading…
Reference in New Issue