Update ClsDataSample design

pull/913/head
mzr1996 2022-05-13 16:55:36 +08:00
parent ce2b40133b
commit 98377df512
2 changed files with 134 additions and 22 deletions

View File

@ -44,7 +44,7 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
if value.max() >= num_classes: if value.max() >= num_classes:
raise ValueError(f'The label data ({value}) should not ' raise ValueError(f'The label data ({value}) should not '
f'exceed num_classes ({num_classes}).') f'exceed num_classes ({num_classes}).')
label = LabelData(item=value, metainfo=metainfo) label = LabelData(label=value, metainfo=metainfo)
return label return label
@ -95,32 +95,83 @@ class ClsDataSample(BaseDataElement):
META INFORMATION META INFORMATION
num_classes: 5 num_classes: 5
DATA FIELDS DATA FIELDS
item: tensor([0, 1, 4]) label: tensor([0, 1, 4])
) at 0x7fd7d1b41970> ) at 0x7fd7d1b41970>
>>> # Convert to one-hot format >>> # Set one-hot format score
>>> data_sample.gt_label.to_onehot() >>> score = torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])
>>> print(data_sample.gt_label) >>> data_sample.set_pred_score(score)
>>> print(data_sample.pred_label)
<LabelData( <LabelData(
META INFORMATION META INFORMATION
num_classes: 5 num_classes: 5
DATA FIELDS DATA FIELDS
item: tensor([1, 1, 0, 0, 1]) score: tensor([0.1, 0.1, 0.6, 0.1, 0.1])
) at 0x7fd7d1b41970> ) at 0x7fd7d1b41970>
""" """
def set_gt_label( def set_gt_label(
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
Number]) -> None: ) -> 'ClsDataSample':
"""Set the gt_label data.""" """Set label of ``gt_label``."""
label = format_label(value, self.get('num_classes')) label = format_label(value, self.get('num_classes'))
if 'gt_label' in self:
self.gt_label.label = label.label
else:
self.gt_label = label self.gt_label = label
return self
def set_gt_score(self, value: torch.Tensor) -> 'ClsDataSample':
"""Set score of ``gt_label``."""
assert isinstance(value, torch.Tensor), \
f'The value should be a torch.Tensor but got {type(value)}.'
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
if 'num_classes' in self:
assert value.size(0) == self.num_classes, \
f"The length of value ({value.size(0)}) doesn't "\
f'match the num_classes ({self.num_classes}).'
metainfo = {'num_classes': self.num_classes}
else:
metainfo = {'num_classes': value.size(0)}
if 'gt_label' in self:
self.gt_label.score = value
else:
self.gt_label = LabelData(score=value, metainfo=metainfo)
return self
def set_pred_label( def set_pred_label(
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
Number]) -> None: ) -> 'ClsDataSample':
"""Set the pred_label data.""" """Set label of ``pred_label``."""
label = format_label(value, self.get('num_classes')) label = format_label(value, self.get('num_classes'))
if 'pred_label' in self:
self.pred_label.label = label.label
else:
self.pred_label = label self.pred_label = label
return self
def set_pred_score(self, value: torch.Tensor) -> 'ClsDataSample':
"""Set score of ``pred_label``."""
assert isinstance(value, torch.Tensor), \
f'The value should be a torch.Tensor but got {type(value)}.'
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
if 'num_classes' in self:
assert value.size(0) == self.num_classes, \
f"The length of value ({value.size(0)}) doesn't "\
f'match the num_classes ({self.num_classes}).'
metainfo = {'num_classes': self.num_classes}
else:
metainfo = {'num_classes': value.size(0)}
if 'pred_label' in self:
self.pred_label.score = value
else:
self.pred_label = LabelData(score=value, metainfo=metainfo)
return self
@property @property
def gt_label(self): def gt_label(self):

View File

@ -18,50 +18,50 @@ class TestClsDataSample(TestCase):
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertIsInstance(label, LabelData)
self.assertIsInstance(label.item, torch.LongTensor) self.assertIsInstance(label.label, torch.LongTensor)
# Test tensor with single number # Test tensor with single number
method(torch.tensor(2)) method(torch.tensor(2))
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertIsInstance(label, LabelData)
self.assertIsInstance(label.item, torch.LongTensor) self.assertIsInstance(label.label, torch.LongTensor)
# Test array with single number # Test array with single number
method(np.array(3)) method(np.array(3))
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertIsInstance(label, LabelData)
self.assertIsInstance(label.item, torch.LongTensor) self.assertIsInstance(label.label, torch.LongTensor)
# Test tensor # Test tensor
method(torch.tensor([1, 2, 3])) method(torch.tensor([1, 2, 3]))
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertIsInstance(label, LabelData)
self.assertIsInstance(label.item, torch.Tensor) self.assertIsInstance(label.label, torch.Tensor)
self.assertTrue((label.item == torch.tensor([1, 2, 3])).all()) self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
# Test array # Test array
method(np.array([1, 2, 3])) method(np.array([1, 2, 3]))
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertIsInstance(label, LabelData)
self.assertTrue((label.item == torch.tensor([1, 2, 3])).all()) self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
# Test Sequence # Test Sequence
method([1, 2, 3]) method([1, 2, 3])
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertIsInstance(label, LabelData)
self.assertTrue((label.item == torch.tensor([1, 2, 3])).all()) self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
# Test Sequence with float number # Test Sequence with float number
method([0.2, 0, 0.8]) method([0.2, 0, 0.8])
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertIsInstance(label, LabelData)
self.assertTrue((label.item == torch.tensor([0.2, 0, 0.8])).all()) self.assertTrue((label.label == torch.tensor([0.2, 0, 0.8])).all())
# Test unavailable type # Test unavailable type
with self.assertRaisesRegex(TypeError, "<class 'str'> is not"): with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
@ -102,3 +102,64 @@ class TestClsDataSample(TestCase):
self.assertIn('pred_label', data_sample) self.assertIn('pred_label', data_sample)
del data_sample.pred_label del data_sample.pred_label
self.assertNotIn('pred_label', data_sample) self.assertNotIn('pred_label', data_sample)
def test_set_gt_score(self):
data_sample = ClsDataSample(metainfo={'num_classes': 5})
data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertIn('score', data_sample.gt_label)
torch.testing.assert_allclose(data_sample.gt_label.score,
[0.1, 0.1, 0.6, 0.1, 0.1])
self.assertEqual(data_sample.gt_label.num_classes, 5)
# Test set again
data_sample.set_gt_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
torch.testing.assert_allclose(data_sample.gt_label.score,
[0.2, 0.1, 0.5, 0.1, 0.1])
# Test invalid type
with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
data_sample.set_gt_score([1, 2, 3])
# Test invalid dims
with self.assertRaisesRegex(AssertionError, 'but got 2'):
data_sample.set_gt_score(torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
# Test invalid num_classes
with self.assertRaisesRegex(AssertionError, r'length of value \(4\)'):
data_sample.set_gt_score(torch.tensor([0.1, 0.2, 0.3, 0.4]))
# Test auto inter num_classes
data_sample = ClsDataSample()
data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertEqual(data_sample.gt_label.num_classes, 5)
def test_set_pred_score(self):
data_sample = ClsDataSample(metainfo={'num_classes': 5})
data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertIn('score', data_sample.pred_label)
torch.testing.assert_allclose(data_sample.pred_label.score,
[0.1, 0.1, 0.6, 0.1, 0.1])
self.assertEqual(data_sample.pred_label.num_classes, 5)
# Test set again
data_sample.set_pred_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
torch.testing.assert_allclose(data_sample.pred_label.score,
[0.2, 0.1, 0.5, 0.1, 0.1])
# Test invalid type
with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
data_sample.set_pred_score([1, 2, 3])
# Test invalid dims
with self.assertRaisesRegex(AssertionError, 'but got 2'):
data_sample.set_pred_score(
torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
# Test invalid num_classes
with self.assertRaisesRegex(AssertionError, r'length of value \(4\)'):
data_sample.set_pred_score(torch.tensor([0.1, 0.2, 0.3, 0.4]))
# Test auto inter num_classes
data_sample = ClsDataSample()
data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertEqual(data_sample.pred_label.num_classes, 5)