Force `label` in ClsDatasample long type.

pull/913/head
mzr1996 2022-06-07 22:11:05 +08:00
parent f3299b4ca2
commit f0cab33e09
2 changed files with 4 additions and 10 deletions

View File

@ -30,9 +30,9 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
value = int(value.item())
if isinstance(value, np.ndarray):
value = torch.from_numpy(value)
value = torch.from_numpy(value).to(torch.long)
elif isinstance(value, Sequence) and not mmcv.is_str(value):
value = torch.tensor(value)
value = torch.tensor(value).to(torch.long)
elif isinstance(value, int):
value = torch.LongTensor([value])
elif not isinstance(value, torch.Tensor):
@ -44,7 +44,8 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
if value.max() >= num_classes:
raise ValueError(f'The label data ({value}) should not '
f'exceed num_classes ({num_classes}).')
label = LabelData(label=value, metainfo=metainfo)
label = LabelData(metainfo=metainfo)
label.set_field(value, 'label', torch.LongTensor)
return label

View File

@ -56,13 +56,6 @@ class TestClsDataSample(TestCase):
self.assertIsInstance(label, LabelData)
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
# Test Sequence with float number
method([0.2, 0, 0.8])
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertTrue((label.label == torch.tensor([0.2, 0, 0.8])).all())
# Test unavailable type
with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
method('hi')