mirror of https://github.com/open-mmlab/mmengine.git
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import pytest
import torch

from mmengine.structures import LabelData


class TestLabelData(TestCase):

    def test_label_to_onehot(self):
        item = torch.tensor([1], dtype=torch.int64)
        num_classes = 10
        onehot = LabelData.label_to_onehot(label=item, num_classes=num_classes)
        assert tuple(onehot.shape) == (num_classes, )
        assert onehot.device == item.device

        # `label` must be a tensor, not a string
        with self.assertRaises(AssertionError):
            LabelData.label_to_onehot(label='item', num_classes=num_classes)

        # the maximum label value must be smaller than `num_classes`
        with self.assertRaises(AssertionError):
            LabelData.label_to_onehot(
                torch.tensor([11], dtype=torch.int64), num_classes)

        # an empty label converts to an all-zero one-hot vector
        onehot = LabelData.label_to_onehot(
            label=torch.tensor([], dtype=torch.int64), num_classes=num_classes)
        assert (onehot == torch.zeros((num_classes, ),
                                      dtype=torch.int64)).all()

    def test_onehot_to_label(self):
        # the input is not a valid one-hot tensor
        with self.assertRaisesRegex(
                ValueError,
                'input is not one-hot and can not convert to label'):
            LabelData.onehot_to_label(
                onehot=torch.tensor([2], dtype=torch.int64))

        # `onehot` must be a tensor, not a string
        with self.assertRaises(AssertionError):
            LabelData.onehot_to_label(onehot='item')

        # round-trip label -> one-hot -> label preserves values and device
        item = torch.arange(0, 9)
        onehot = LabelData.label_to_onehot(item, num_classes=10)
        label = LabelData.onehot_to_label(onehot)
        assert (label == item).all()
        assert label.device == item.device

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='GPU is required!')
    def test_cuda(self):
        item = torch.arange(0, 9).cuda()
        onehot = LabelData.label_to_onehot(item, num_classes=10)
        assert item.device == onehot.device
        label = LabelData.onehot_to_label(onehot)
        assert label.device == onehot.device