Update ClsDataSample design

pull/913/head
mzr1996 2022-05-13 16:55:36 +08:00
parent ce2b40133b
commit 98377df512
2 changed files with 134 additions and 22 deletions

View File

@ -44,7 +44,7 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
if value.max() >= num_classes:
raise ValueError(f'The label data ({value}) should not '
f'exceed num_classes ({num_classes}).')
label = LabelData(item=value, metainfo=metainfo)
label = LabelData(label=value, metainfo=metainfo)
return label
@ -95,32 +95,83 @@ class ClsDataSample(BaseDataElement):
META INFORMATION
num_classes: 5
DATA FIELDS
item: tensor([0, 1, 4])
label: tensor([0, 1, 4])
) at 0x7fd7d1b41970>
>>> # Convert to one-hot format
>>> data_sample.gt_label.to_onehot()
>>> print(data_sample.gt_label)
>>> # Set one-hot format score
>>> score = torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])
>>> data_sample.set_pred_score(score)
>>> print(data_sample.pred_label)
<LabelData(
META INFORMATION
num_classes: 5
DATA FIELDS
item: tensor([1, 1, 0, 0, 1])
score: tensor([0.1, 0.1, 0.6, 0.1, 0.1])
) at 0x7fd7d1b41970>
"""
# NOTE(review): this span appears to be a flattened diff hunk -- both the
# previous and the updated signature/docstring/assignment lines are present
# and indentation is missing. Confirm against the real file before editing.
def set_gt_label(
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number],
Number]) -> None:
"""Set the gt_label data."""
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
) -> 'ClsDataSample':
"""Set label of ``gt_label``."""
# Convert ``value`` with format_label, passing whatever ``num_classes``
# self.get() yields (presumably None when it is unset -- TODO confirm).
label = format_label(value, self.get('num_classes'))
self.gt_label = label
# If a gt_label is already present, only its ``label`` field is overwritten;
# otherwise the freshly formatted object is stored as gt_label.
# Presumably this keeps other fields (e.g. a score) intact -- confirm.
if 'gt_label' in self:
self.gt_label.label = label.label
else:
self.gt_label = label
# Returns the data sample itself.
return self
def set_gt_score(self, value: torch.Tensor) -> 'ClsDataSample':
"""Set score of ``gt_label``."""
# Input checks: ``value`` must be a torch.Tensor with exactly one dimension.
assert isinstance(value, torch.Tensor), \
f'The value should be a torch.Tensor but got {type(value)}.'
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
# When ``num_classes`` is already recorded on this sample, the score length
# must equal it; otherwise the length of ``value`` is taken as num_classes.
if 'num_classes' in self:
assert value.size(0) == self.num_classes, \
f"The length of value ({value.size(0)}) doesn't "\
f'match the num_classes ({self.num_classes}).'
metainfo = {'num_classes': self.num_classes}
else:
metainfo = {'num_classes': value.size(0)}
# An existing gt_label only has its ``score`` field assigned; ``metainfo`` is
# used solely when a new LabelData has to be created.
if 'gt_label' in self:
self.gt_label.score = value
else:
self.gt_label = LabelData(score=value, metainfo=metainfo)
# Returns the data sample itself.
return self
# NOTE(review): this span appears to be a flattened diff hunk -- both the
# previous and the updated signature/docstring/assignment lines are present
# and indentation is missing. Confirm against the real file before editing.
def set_pred_label(
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number],
Number]) -> None:
"""Set the pred_label data."""
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
) -> 'ClsDataSample':
"""Set label of ``pred_label``."""
# Convert ``value`` with format_label, passing whatever ``num_classes``
# self.get() yields (presumably None when it is unset -- TODO confirm).
label = format_label(value, self.get('num_classes'))
self.pred_label = label
# If a pred_label is already present, only its ``label`` field is
# overwritten; otherwise the freshly formatted object is stored as
# pred_label. Presumably this keeps other fields (e.g. a score) -- confirm.
if 'pred_label' in self:
self.pred_label.label = label.label
else:
self.pred_label = label
# Returns the data sample itself.
return self
def set_pred_score(self, value: torch.Tensor) -> 'ClsDataSample':
"""Set score of ``pred_label``."""
# Input checks: ``value`` must be a torch.Tensor with exactly one dimension.
assert isinstance(value, torch.Tensor), \
f'The value should be a torch.Tensor but got {type(value)}.'
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
# When ``num_classes`` is already recorded on this sample, the score length
# must equal it; otherwise the length of ``value`` is taken as num_classes.
if 'num_classes' in self:
assert value.size(0) == self.num_classes, \
f"The length of value ({value.size(0)}) doesn't "\
f'match the num_classes ({self.num_classes}).'
metainfo = {'num_classes': self.num_classes}
else:
metainfo = {'num_classes': value.size(0)}
# An existing pred_label only has its ``score`` field assigned; ``metainfo``
# is used solely when a new LabelData has to be created.
if 'pred_label' in self:
self.pred_label.score = value
else:
self.pred_label = LabelData(score=value, metainfo=metainfo)
# Returns the data sample itself.
return self
@property
def gt_label(self):

View File

@ -18,50 +18,50 @@ class TestClsDataSample(TestCase):
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertIsInstance(label.item, torch.LongTensor)
self.assertIsInstance(label.label, torch.LongTensor)
# Test tensor with single number
method(torch.tensor(2))
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertIsInstance(label.item, torch.LongTensor)
self.assertIsInstance(label.label, torch.LongTensor)
# Test array with single number
method(np.array(3))
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertIsInstance(label.item, torch.LongTensor)
self.assertIsInstance(label.label, torch.LongTensor)
# Test tensor
method(torch.tensor([1, 2, 3]))
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertIsInstance(label.item, torch.Tensor)
self.assertTrue((label.item == torch.tensor([1, 2, 3])).all())
self.assertIsInstance(label.label, torch.Tensor)
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
# Test array
method(np.array([1, 2, 3]))
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertTrue((label.item == torch.tensor([1, 2, 3])).all())
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
# Test Sequence
method([1, 2, 3])
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertTrue((label.item == torch.tensor([1, 2, 3])).all())
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
# Test Sequence with float number
method([0.2, 0, 0.8])
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertTrue((label.item == torch.tensor([0.2, 0, 0.8])).all())
self.assertTrue((label.label == torch.tensor([0.2, 0, 0.8])).all())
# Test unavailable type
with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
@ -102,3 +102,64 @@ class TestClsDataSample(TestCase):
self.assertIn('pred_label', data_sample)
del data_sample.pred_label
self.assertNotIn('pred_label', data_sample)
def test_set_gt_score(self):
# Exercises ClsDataSample.set_gt_score: the score is stored on gt_label,
# num_classes is kept, a second call overwrites the score, invalid inputs
# raise AssertionError, and num_classes is inferred when not preset.
data_sample = ClsDataSample(metainfo={'num_classes': 5})
data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertIn('score', data_sample.gt_label)
torch.testing.assert_allclose(data_sample.gt_label.score,
[0.1, 0.1, 0.6, 0.1, 0.1])
self.assertEqual(data_sample.gt_label.num_classes, 5)
# Test set again
data_sample.set_gt_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
torch.testing.assert_allclose(data_sample.gt_label.score,
[0.2, 0.1, 0.5, 0.1, 0.1])
# Test invalid type
with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
data_sample.set_gt_score([1, 2, 3])
# Test invalid dims
with self.assertRaisesRegex(AssertionError, 'but got 2'):
data_sample.set_gt_score(torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
# Test invalid num_classes
with self.assertRaisesRegex(AssertionError, r'length of value \(4\)'):
data_sample.set_gt_score(torch.tensor([0.1, 0.2, 0.3, 0.4]))
# Test that num_classes is inferred automatically from the score length
data_sample = ClsDataSample()
data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertEqual(data_sample.gt_label.num_classes, 5)
def test_set_pred_score(self):
# Exercises ClsDataSample.set_pred_score: the score is stored on pred_label,
# num_classes is kept, a second call overwrites the score, invalid inputs
# raise AssertionError, and num_classes is inferred when not preset.
data_sample = ClsDataSample(metainfo={'num_classes': 5})
data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertIn('score', data_sample.pred_label)
torch.testing.assert_allclose(data_sample.pred_label.score,
[0.1, 0.1, 0.6, 0.1, 0.1])
self.assertEqual(data_sample.pred_label.num_classes, 5)
# Test set again
data_sample.set_pred_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
torch.testing.assert_allclose(data_sample.pred_label.score,
[0.2, 0.1, 0.5, 0.1, 0.1])
# Test invalid type
with self.assertRaisesRegex(AssertionError, 'be a torch.Tensor'):
data_sample.set_pred_score([1, 2, 3])
# Test invalid dims
with self.assertRaisesRegex(AssertionError, 'but got 2'):
data_sample.set_pred_score(
torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
# Test invalid num_classes
with self.assertRaisesRegex(AssertionError, r'length of value \(4\)'):
data_sample.set_pred_score(torch.tensor([0.1, 0.2, 0.3, 0.4]))
# Test that num_classes is inferred automatically from the score length
data_sample = ClsDataSample()
data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertEqual(data_sample.pred_label.num_classes, 5)