mmclassification/tests/test_structures/test_datasample.py

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import torch

from mmpretrain.structures import DataSample, MultiTaskDataSample


class TestDataSample(TestCase):

    def _test_set_label(self, key):
        data_sample = DataSample()
        method = getattr(data_sample, 'set_' + key)

        # Test number
        method(1)
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertIsInstance(label, torch.LongTensor)

        # Test tensor with single number
        method(torch.tensor(2))
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertIsInstance(label, torch.LongTensor)

        # Test array with single number
        method(np.array(3))
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertIsInstance(label, torch.LongTensor)

        # Test tensor
        method(torch.tensor([1, 2, 3]))
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertIsInstance(label, torch.Tensor)
        self.assertTrue((label == torch.tensor([1, 2, 3])).all())

        # Test array
        method(np.array([1, 2, 3]))
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertTrue((label == torch.tensor([1, 2, 3])).all())

        # Test Sequence
        method([1, 2, 3])
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertTrue((label == torch.tensor([1, 2, 3])).all())

        # Test unavailable type
        with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
            method('hi')

    def test_set_gt_label(self):
        self._test_set_label('gt_label')

    def test_set_pred_label(self):
        self._test_set_label('pred_label')

    def test_set_gt_score(self):
        data_sample = DataSample()
        data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
        self.assertIn('gt_score', data_sample)
        torch.testing.assert_allclose(data_sample.gt_score,
                                      [0.1, 0.1, 0.6, 0.1, 0.1])

        # Test invalid length
        with self.assertRaisesRegex(AssertionError, 'should be equal to'):
            data_sample.set_gt_score([1, 2])

        # Test invalid dims
        with self.assertRaisesRegex(AssertionError, 'but got 2'):
            data_sample.set_gt_score(torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))

    def test_set_pred_score(self):
        data_sample = DataSample()
        data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
        self.assertIn('pred_score', data_sample)
        torch.testing.assert_allclose(data_sample.pred_score,
                                      [0.1, 0.1, 0.6, 0.1, 0.1])

        # Test invalid length (exercise the pred setter, not the gt setter)
        with self.assertRaisesRegex(AssertionError, 'should be equal to'):
            data_sample.set_pred_score([1, 2])

        # Test invalid dims
        with self.assertRaisesRegex(AssertionError, 'but got 2'):
            data_sample.set_pred_score(
                torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))


class TestMultiTaskDataSample(TestCase):

    def test_multi_task_data_sample(self):
        gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
        data_sample = MultiTaskDataSample()
        task_sample = DataSample().set_gt_label(gt_label['task1'])
        data_sample.set_field(task_sample, 'task1')
        data_sample.set_field(MultiTaskDataSample(), 'task0')
        for task_name in gt_label['task0']:
            task_sample = DataSample().set_gt_label(
                gt_label['task0'][task_name])
            data_sample.task0.set_field(task_sample, task_name)
        self.assertIsInstance(data_sample.task0, MultiTaskDataSample)
        self.assertIsInstance(data_sample.task1, DataSample)
        self.assertIsInstance(data_sample.task0.task00, DataSample)