mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Docs] update data element tutorials (#431)
* structure tutorials * refine data element docs * modify introduce * fix comment * fix comment * fix comment
This commit is contained in:
parent
5a9ac09f28
commit
1fea82aad5
File diff suppressed because it is too large
Load Diff
@ -22,7 +22,7 @@ class LabelData(BaseDataElement):
|
||||
assert isinstance(onehot, torch.Tensor)
|
||||
if (onehot.ndim == 1 and onehot.max().item() <= 1
|
||||
and onehot.min().item() >= 0):
|
||||
return onehot.nonzero().squeeze()
|
||||
return onehot.nonzero().squeeze(-1)
|
||||
else:
|
||||
raise ValueError(
|
||||
'input is not one-hot and can not convert to label')
|
||||
|
@ -44,6 +44,11 @@ class TestLabelData(TestCase):
|
||||
label = LabelData.onehot_to_label(onehot)
|
||||
assert (label == item).all()
|
||||
assert label.device == item.device
|
||||
item = torch.tensor([2])
|
||||
onehot = LabelData.label_to_onehot(item, num_classes=10)
|
||||
label = LabelData.onehot_to_label(onehot)
|
||||
assert label == item
|
||||
assert label.device == item.device
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='GPU is required!')
|
||||
|
Loading…
x
Reference in New Issue
Block a user