mmpretrain/tests/test_structures/test_datasample.py

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import torch
from mmengine.structures import LabelData

from mmcls.structures import ClsDataSample, MultiTaskDataSample
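

# These tests cover the two data-sample containers: the flat ClsDataSample
# (single-task labels and scores) and the nested MultiTaskDataSample
# introduced for multi-task training support (#1229).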
class TestClsDataSample(TestCase):

    def _test_set_label(self, key):
        data_sample = ClsDataSample()
        method = getattr(data_sample, 'set_' + key)
        # Test number
        method(1)
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # Test tensor with single number
        method(torch.tensor(2))
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # Test array with single number
        method(np.array(3))
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.LongTensor)

        # Test tensor
        method(torch.tensor([1, 2, 3]))
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertIsInstance(label, LabelData)
        self.assertIsInstance(label.label, torch.Tensor)
        self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())

        # Test array
        method(np.array([1, 2, 3]))
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertIsInstance(label, LabelData)
        self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())

        # Test Sequence
        method([1, 2, 3])
        self.assertIn(key, data_sample)
        label = getattr(data_sample, key)
        self.assertIsInstance(label, LabelData)
        self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())

        # Test unavailable type
        with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
            method('hi')

    def test_set_gt_label(self):
        self._test_set_label('gt_label')

    def test_set_pred_label(self):
        self._test_set_label('pred_label')

    def test_del_gt_label(self):
        data_sample = ClsDataSample()
        self.assertNotIn('gt_label', data_sample)
        data_sample.set_gt_label(1)
        self.assertIn('gt_label', data_sample)
        del data_sample.gt_label
        self.assertNotIn('gt_label', data_sample)

    def test_del_pred_label(self):
        data_sample = ClsDataSample()
        self.assertNotIn('pred_label', data_sample)
        data_sample.set_pred_label(1)
        self.assertIn('pred_label', data_sample)
        del data_sample.pred_label
        self.assertNotIn('pred_label', data_sample)

    def test_set_gt_score(self):
        data_sample = ClsDataSample()
        data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
        self.assertIn('score', data_sample.gt_label)
        torch.testing.assert_allclose(data_sample.gt_label.score,
                                      [0.1, 0.1, 0.6, 0.1, 0.1])

        # Test set again
        data_sample.set_gt_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
        torch.testing.assert_allclose(data_sample.gt_label.score,
                                      [0.2, 0.1, 0.5, 0.1, 0.1])

        # Test invalid length
        with self.assertRaisesRegex(AssertionError, 'should be equal to'):
            data_sample.set_gt_score([1, 2])

        # Test invalid dims
        with self.assertRaisesRegex(AssertionError, 'but got 2'):
            data_sample.set_gt_score(
                torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))

    def test_set_pred_score(self):
        data_sample = ClsDataSample()
        data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
        self.assertIn('score', data_sample.pred_label)
        torch.testing.assert_allclose(data_sample.pred_label.score,
                                      [0.1, 0.1, 0.6, 0.1, 0.1])

        # Test set again
        data_sample.set_pred_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
        torch.testing.assert_allclose(data_sample.pred_label.score,
                                      [0.2, 0.1, 0.5, 0.1, 0.1])

        # Test invalid length (was `set_gt_score`, a copy-paste slip; this
        # test exercises the pred setter)
        with self.assertRaisesRegex(AssertionError, 'should be equal to'):
            data_sample.set_pred_score([1, 2])

        # Test invalid dims
        with self.assertRaisesRegex(AssertionError, 'but got 2'):
            data_sample.set_pred_score(
                torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
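

# The multi-task test below builds samples in the chained style
# `ClsDataSample().set_gt_label(...)`, which assumes the label setters
# return `self`; otherwise the assignment to `task_sample` would bind None.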
class TestMultiTaskDataSample(TestCase):

    def test_multi_task_data_sample(self):
        # 'task0' is itself a group of sub-tasks; 'task1' is a plain task.
        gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
        data_sample = MultiTaskDataSample()
        task_sample = ClsDataSample().set_gt_label(gt_label['task1'])
        data_sample.set_field(task_sample, 'task1')
        # Nest another MultiTaskDataSample to hold the 'task0' sub-tasks.
        data_sample.set_field(MultiTaskDataSample(), 'task0')
        for task_name in gt_label['task0']:
            task_sample = ClsDataSample().set_gt_label(
                gt_label['task0'][task_name])
            data_sample.task0.set_field(task_sample, task_name)
        self.assertIsInstance(data_sample.task0, MultiTaskDataSample)
        self.assertIsInstance(data_sample.task1, ClsDataSample)
        self.assertIsInstance(data_sample.task0.task00, ClsDataSample)
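

# A minimal, runnable sketch of the nested layout exercised above. It only
# assumes what the test itself uses -- `set_field`, attribute access, and the
# `keys()` accessor inherited from mmengine's BaseDataElement -- plus the
# chained `set_gt_label` returning the sample. Run the file directly to
# print the task layout; pytest ignores this block.
if __name__ == '__main__':
    demo = MultiTaskDataSample()
    demo.set_field(ClsDataSample().set_gt_label(1), 'task1')
    demo.set_field(MultiTaskDataSample(), 'task0')
    demo.task0.set_field(ClsDataSample().set_gt_label(2), 'task00')
    print(sorted(demo.keys()))               # ['task0', 'task1']
    print(sorted(demo.task0.keys()))         # ['task00']
    print(demo.task0.task00.gt_label.label)  # tensor([2])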