Update ClsDataSample design
parent
ce2b40133b
commit
98377df512
|
@ -44,7 +44,7 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
|
|||
if value.max() >= num_classes:
|
||||
raise ValueError(f'The label data ({value}) should not '
|
||||
f'exceed num_classes ({num_classes}).')
|
||||
label = LabelData(item=value, metainfo=metainfo)
|
||||
label = LabelData(label=value, metainfo=metainfo)
|
||||
return label
|
||||
|
||||
|
||||
|
@ -95,32 +95,83 @@ class ClsDataSample(BaseDataElement):
|
|||
META INFORMATION
|
||||
num_classes: 5
|
||||
DATA FIELDS
|
||||
item: tensor([0, 1, 4])
|
||||
label: tensor([0, 1, 4])
|
||||
) at 0x7fd7d1b41970>
|
||||
>>> # Convert to one-hot format
|
||||
>>> data_sample.gt_label.to_onehot()
|
||||
>>> print(data_sample.gt_label)
|
||||
>>> # Set one-hot format score
|
||||
>>> score = torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])
|
||||
>>> data_sample.set_pred_score(score)
|
||||
>>> print(data_sample.pred_label)
|
||||
<LabelData(
|
||||
META INFORMATION
|
||||
num_classes: 5
|
||||
DATA FIELDS
|
||||
item: tensor([1, 1, 0, 0, 1])
|
||||
score: tensor([0.1, 0.1, 0.6, 0.1, 0.1])
|
||||
) at 0x7fd7d1b41970>
|
||||
"""
|
||||
|
||||
def set_gt_label(
|
||||
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number],
|
||||
Number]) -> None:
|
||||
"""Set the gt_label data."""
|
||||
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
|
||||
) -> 'ClsDataSample':
|
||||
"""Set label of ``gt_label``."""
|
||||
label = format_label(value, self.get('num_classes'))
|
||||
if 'gt_label' in self:
|
||||
self.gt_label.label = label.label
|
||||
else:
|
||||
self.gt_label = label
|
||||
return self
|
||||
|
||||
def set_gt_score(self, value: torch.Tensor) -> 'ClsDataSample':
|
||||
"""Set score of ``gt_label``."""
|
||||
assert isinstance(value, torch.Tensor), \
|
||||
f'The value should be a torch.Tensor but got {type(value)}.'
|
||||
assert value.ndim == 1, \
|
||||
f'The dims of value should be 1, but got {value.ndim}.'
|
||||
|
||||
if 'num_classes' in self:
|
||||
assert value.size(0) == self.num_classes, \
|
||||
f"The length of value ({value.size(0)}) doesn't "\
|
||||
f'match the num_classes ({self.num_classes}).'
|
||||
metainfo = {'num_classes': self.num_classes}
|
||||
else:
|
||||
metainfo = {'num_classes': value.size(0)}
|
||||
|
||||
if 'gt_label' in self:
|
||||
self.gt_label.score = value
|
||||
else:
|
||||
self.gt_label = LabelData(score=value, metainfo=metainfo)
|
||||
return self
|
||||
|
||||
def set_pred_label(
|
||||
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number],
|
||||
Number]) -> None:
|
||||
"""Set the pred_label data."""
|
||||
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
|
||||
) -> 'ClsDataSample':
|
||||
"""Set label of ``pred_label``."""
|
||||
label = format_label(value, self.get('num_classes'))
|
||||
if 'pred_label' in self:
|
||||
self.pred_label.label = label.label
|
||||
else:
|
||||
self.pred_label = label
|
||||
return self
|
||||
|
||||
def set_pred_score(self, value: torch.Tensor) -> 'ClsDataSample':
|
||||
"""Set score of ``pred_label``."""
|
||||
assert isinstance(value, torch.Tensor), \
|
||||
f'The value should be a torch.Tensor but got {type(value)}.'
|
||||
assert value.ndim == 1, \
|
||||
f'The dims of value should be 1, but got {value.ndim}.'
|
||||
|
||||
if 'num_classes' in self:
|
||||
assert value.size(0) == self.num_classes, \
|
||||
f"The length of value ({value.size(0)}) doesn't "\
|
||||
f'match the num_classes ({self.num_classes}).'
|
||||
metainfo = {'num_classes': self.num_classes}
|
||||
else:
|
||||
metainfo = {'num_classes': value.size(0)}
|
||||
|
||||
if 'pred_label' in self:
|
||||
self.pred_label.score = value
|
||||
else:
|
||||
self.pred_label = LabelData(score=value, metainfo=metainfo)
|
||||
return self
|
||||
|
||||
@property
|
||||
def gt_label(self):
|
||||
|
|
|
@ -18,50 +18,50 @@ class TestClsDataSample(TestCase):
|
|||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertIsInstance(label.item, torch.LongTensor)
|
||||
self.assertIsInstance(label.label, torch.LongTensor)
|
||||
|
||||
# Test tensor with single number
|
||||
method(torch.tensor(2))
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertIsInstance(label.item, torch.LongTensor)
|
||||
self.assertIsInstance(label.label, torch.LongTensor)
|
||||
|
||||
# Test array with single number
|
||||
method(np.array(3))
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertIsInstance(label.item, torch.LongTensor)
|
||||
self.assertIsInstance(label.label, torch.LongTensor)
|
||||
|
||||
# Test tensor
|
||||
method(torch.tensor([1, 2, 3]))
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertIsInstance(label.item, torch.Tensor)
|
||||
self.assertTrue((label.item == torch.tensor([1, 2, 3])).all())
|
||||
self.assertIsInstance(label.label, torch.Tensor)
|
||||
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
|
||||
|
||||
# Test array
|
||||
method(np.array([1, 2, 3]))
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertTrue((label.item == torch.tensor([1, 2, 3])).all())
|
||||
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
|
||||
|
||||
# Test Sequence
|
||||
method([1, 2, 3])
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertTrue((label.item == torch.tensor([1, 2, 3])).all())
|
||||
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
|
||||
|
||||
# Test Sequence with float number
|
||||
method([0.2, 0, 0.8])
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertTrue((label.item == torch.tensor([0.2, 0, 0.8])).all())
|
||||
self.assertTrue((label.label == torch.tensor([0.2, 0, 0.8])).all())
|
||||
|
||||
# Test unavailable type
|
||||
with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
|
||||
|
@ -102,3 +102,64 @@ class TestClsDataSample(TestCase):
|
|||
self.assertIn('pred_label', data_sample)
|
||||
del data_sample.pred_label
|
||||
self.assertNotIn('pred_label', data_sample)
|
||||
|
||||
def test_set_gt_score(self):
|
||||
data_sample = ClsDataSample(metainfo={'num_classes': 5})
|
||||
data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
|
||||
self.assertIn('score', data_sample.gt_label)
|
||||
torch.testing.assert_allclose(data_sample.gt_label.score,
|
||||
[0.1, 0.1, 0.6, 0.1, 0.1])
|
||||
self.assertEqual(data_sample.gt_label.num_classes, 5)
|
||||
|
||||
# Test set again
|
||||
data_sample.set_gt_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
|
||||
torch.testing.assert_allclose(data_sample.gt_label.score,
|
||||
[0.2, 0.1, 0.5, 0.1, 0.1])
|
||||
|
||||
# Test invalid type
|
||||
with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
|
||||
data_sample.set_gt_score([1, 2, 3])
|
||||
|
||||
# Test invalid dims
|
||||
with self.assertRaisesRegex(AssertionError, 'but got 2'):
|
||||
data_sample.set_gt_score(torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
|
||||
|
||||
# Test invalid num_classes
|
||||
with self.assertRaisesRegex(AssertionError, r'length of value \(4\)'):
|
||||
data_sample.set_gt_score(torch.tensor([0.1, 0.2, 0.3, 0.4]))
|
||||
|
||||
# Test auto inter num_classes
|
||||
data_sample = ClsDataSample()
|
||||
data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
|
||||
self.assertEqual(data_sample.gt_label.num_classes, 5)
|
||||
|
||||
def test_set_pred_score(self):
|
||||
data_sample = ClsDataSample(metainfo={'num_classes': 5})
|
||||
data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
|
||||
self.assertIn('score', data_sample.pred_label)
|
||||
torch.testing.assert_allclose(data_sample.pred_label.score,
|
||||
[0.1, 0.1, 0.6, 0.1, 0.1])
|
||||
self.assertEqual(data_sample.pred_label.num_classes, 5)
|
||||
|
||||
# Test set again
|
||||
data_sample.set_pred_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
|
||||
torch.testing.assert_allclose(data_sample.pred_label.score,
|
||||
[0.2, 0.1, 0.5, 0.1, 0.1])
|
||||
|
||||
# Test invalid type
|
||||
with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
|
||||
data_sample.set_pred_score([1, 2, 3])
|
||||
|
||||
# Test invalid dims
|
||||
with self.assertRaisesRegex(AssertionError, 'but got 2'):
|
||||
data_sample.set_pred_score(
|
||||
torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
|
||||
|
||||
# Test invalid num_classes
|
||||
with self.assertRaisesRegex(AssertionError, r'length of value \(4\)'):
|
||||
data_sample.set_pred_score(torch.tensor([0.1, 0.2, 0.3, 0.4]))
|
||||
|
||||
# Test auto inter num_classes
|
||||
data_sample = ClsDataSample()
|
||||
data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
|
||||
self.assertEqual(data_sample.pred_label.num_classes, 5)
|
||||
|
|
Loading…
Reference in New Issue