Force `label` in ClsDatasample long type.
parent
f3299b4ca2
commit
f0cab33e09
|
@ -30,9 +30,9 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
|
|||
value = int(value.item())
|
||||
|
||||
if isinstance(value, np.ndarray):
|
||||
value = torch.from_numpy(value)
|
||||
value = torch.from_numpy(value).to(torch.long)
|
||||
elif isinstance(value, Sequence) and not mmcv.is_str(value):
|
||||
value = torch.tensor(value)
|
||||
value = torch.tensor(value).to(torch.long)
|
||||
elif isinstance(value, int):
|
||||
value = torch.LongTensor([value])
|
||||
elif not isinstance(value, torch.Tensor):
|
||||
|
@ -44,7 +44,8 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
|
|||
if value.max() >= num_classes:
|
||||
raise ValueError(f'The label data ({value}) should not '
|
||||
f'exceed num_classes ({num_classes}).')
|
||||
label = LabelData(label=value, metainfo=metainfo)
|
||||
label = LabelData(metainfo=metainfo)
|
||||
label.set_field(value, 'label', torch.LongTensor)
|
||||
return label
|
||||
|
||||
|
||||
|
|
|
@ -56,13 +56,6 @@ class TestClsDataSample(TestCase):
|
|||
self.assertIsInstance(label, LabelData)
|
||||
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
|
||||
|
||||
# Test Sequence with float number
|
||||
method([0.2, 0, 0.8])
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertTrue((label.label == torch.tensor([0.2, 0, 0.8])).all())
|
||||
|
||||
# Test unavailable type
|
||||
with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
|
||||
method('hi')
|
||||
|
|
Loading…
Reference in New Issue