From d232912391c414aab74e1720e3d5e5f82fd399a5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eduardo=20L=C3=B3pez?= <86728694+elopezz@users.noreply.github.com>
Date: Wed, 8 Dec 2021 11:15:05 +0100
Subject: [PATCH] [Enhance] Added NumClassCheckHook and unit tests (#559)

* Added NumClassCheckHook and CI tests

* Added HOOKS local registry. NumClassCheckHook and unit test files redistribution.

* Extended hook for supporting IterRunner & EpochRunner. Extended unit test.

* Simplification of ClassNumCheckHook. Minor changes.
---
 mmcls/core/hook/__init__.py               |  4 +
 mmcls/core/hook/class_num_check_hook.py   | 90 +++++++++++++++++++++++
 tests/test_runtime/test_num_class_hook.py | 89 ++++++++++++++++++++++
 3 files changed, 183 insertions(+)
 create mode 100644 mmcls/core/hook/__init__.py
 create mode 100644 mmcls/core/hook/class_num_check_hook.py
 create mode 100644 tests/test_runtime/test_num_class_hook.py

diff --git a/mmcls/core/hook/__init__.py b/mmcls/core/hook/__init__.py
new file mode 100644
index 00000000..6dde44a5
--- /dev/null
+++ b/mmcls/core/hook/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .class_num_check_hook import ClassNumCheckHook
+
+__all__ = ['ClassNumCheckHook']
diff --git a/mmcls/core/hook/class_num_check_hook.py b/mmcls/core/hook/class_num_check_hook.py
new file mode 100644
index 00000000..52c2c9a5
--- /dev/null
+++ b/mmcls/core/hook/class_num_check_hook.py
@@ -0,0 +1,90 @@
+# Copyright (c) OpenMMLab. All rights reserved
+from mmcv.runner import IterBasedRunner
+from mmcv.runner.hooks import HOOKS, Hook
+from mmcv.utils import is_seq_of
+
+
+@HOOKS.register_module()
+class ClassNumCheckHook(Hook):
+    """Check that `num_classes` of the model matches the dataset.
+
+    The check runs from `before_train_epoch` and `before_val_epoch`, and
+    from `before_train_iter` and `before_val_iter` when the runner is an
+    `IterBasedRunner`.
+    """
+
+    def _check_head(self, runner, dataset):
+        """Check whether the `num_classes` in head matches the length of
+        `CLASSES` in `dataset`.
+
+        Only a warning is logged when `dataset.CLASSES` is None.
+
+        Args:
+            runner (obj:`EpochBasedRunner`, `IterBasedRunner`): runner object.
+            dataset (obj: `BaseDataset`): the dataset to check.
+
+        Raises:
+            AssertionError: If `is_seq_of(dataset.CLASSES, str)` fails, or
+                if the `num_classes` of any module of the model differs
+                from the length of `CLASSES`.
+        """
+        model = runner.model
+        if dataset.CLASSES is None:
+            runner.logger.warning(
+                f'Please set `CLASSES` '
+                f'in the {dataset.__class__.__name__} and '
+                f'check if it is consistent with the `num_classes` '
+                f'of head')
+        else:
+            assert is_seq_of(dataset.CLASSES, str), \
+                (f'`CLASSES` in {dataset.__class__.__name__} '
+                 f'should be a tuple of str.')
+            for _, module in model.named_modules():
+                if hasattr(module, 'num_classes'):
+                    assert module.num_classes == len(dataset.CLASSES), \
+                        (f'The `num_classes` ({module.num_classes}) in '
+                         f'{module.__class__.__name__} of '
+                         f'{model.__class__.__name__} does not match '
+                         f'the length of `CLASSES` '
+                         f'({len(dataset.CLASSES)}) in '
+                         f'{dataset.__class__.__name__}')
+
+    def before_train_iter(self, runner):
+        """Check whether the training dataset is compatible with head.
+
+        The check is skipped unless `runner` is an `IterBasedRunner`.
+
+        Args:
+            runner (obj: `IterBasedRunner`): Iter based Runner.
+        """
+        if not isinstance(runner, IterBasedRunner):
+            return
+        self._check_head(runner, runner.data_loader._dataloader.dataset)
+
+    def before_val_iter(self, runner):
+        """Check whether the eval dataset is compatible with head.
+
+        The check is skipped unless `runner` is an `IterBasedRunner`.
+
+        Args:
+            runner (obj:`IterBasedRunner`): Iter based Runner.
+        """
+        if not isinstance(runner, IterBasedRunner):
+            return
+        self._check_head(runner, runner.data_loader._dataloader.dataset)
+
+    def before_train_epoch(self, runner):
+        """Check whether the training dataset is compatible with head.
+
+        Args:
+            runner (obj:`EpochBasedRunner`): Epoch based Runner.
+        """
+        self._check_head(runner, runner.data_loader.dataset)
+
+    def before_val_epoch(self, runner):
+        """Check whether the eval dataset is compatible with head.
+
+        Args:
+            runner (obj:`EpochBasedRunner`): Epoch based Runner.
+        """
+        self._check_head(runner, runner.data_loader.dataset)
diff --git a/tests/test_runtime/test_num_class_hook.py b/tests/test_runtime/test_num_class_hook.py
new file mode 100644
index 00000000..b5ed432b
--- /dev/null
+++ b/tests/test_runtime/test_num_class_hook.py
@@ -0,0 +1,89 @@
+import logging
+import tempfile
+from unittest.mock import MagicMock
+
+import mmcv.runner as mmcv_runner
+import pytest
+import torch
+from mmcv.runner import obj_from_dict
+from torch.utils.data import DataLoader, Dataset
+
+from mmcls.core.hook import ClassNumCheckHook
+from mmcls.models.heads.base_head import BaseHead
+
+
+class ExampleDataset(Dataset):
+
+    def __init__(self, CLASSES):
+        self.CLASSES = CLASSES
+
+    def __getitem__(self, idx):
+        results = dict(img=torch.tensor([1]), img_metas=dict())
+        return results
+
+    def __len__(self):
+        return 1
+
+
+class ExampleHead(BaseHead):
+
+    def __init__(self, init_cfg=None):
+        super(ExampleHead, self).__init__(init_cfg)
+        self.num_classes = 4
+
+    def forward_train(self, x, gt_label=None, **kwargs):
+        pass
+
+
+class ExampleModel(torch.nn.Module):
+
+    def __init__(self):
+        super(ExampleModel, self).__init__()
+        self.test_cfg = None
+        self.conv = torch.nn.Conv2d(3, 3, 3)
+        self.head = ExampleHead()
+
+    def forward(self, img, img_metas, test_mode=False, **kwargs):
+        return img
+
+    def train_step(self, data_batch, optimizer):
+        loss = self.forward(**data_batch)
+        return dict(loss=loss)
+
+
+@pytest.mark.parametrize('runner_type',
+                         ['EpochBasedRunner', 'IterBasedRunner'])
+@pytest.mark.parametrize(
+    'CLASSES', [None, ('A', 'B', 'C', 'D', 'E'), ('A', 'B', 'C', 'D')])
+def test_num_class_hook(runner_type, CLASSES):
+    """Check the hook's behaviour for each runner type and `CLASSES`.
+
+    `ExampleHead.num_classes` is 4, so `CLASSES=None` is expected to log a
+    warning, five class names to raise `AssertionError`, and four class
+    names to run without error.
+    """
+    test_dataset = ExampleDataset(CLASSES)
+    loader = DataLoader(test_dataset, batch_size=1)
+    model = ExampleModel()
+    optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+    optimizer = obj_from_dict(optim_cfg, torch.optim,
+                              dict(params=model.parameters()))
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        num_class_hook = ClassNumCheckHook()
+        logger_mock = MagicMock(spec=logging.Logger)
+        runner = getattr(mmcv_runner, runner_type)(
+            model=model,
+            optimizer=optimizer,
+            work_dir=tmpdir,
+            logger=logger_mock,
+            max_epochs=1)
+        runner.register_hook(num_class_hook)
+        if CLASSES is None:
+            runner.run([loader], [('train', 1)], 1)
+            logger_mock.warning.assert_called()
+        elif len(CLASSES) != 4:
+            with pytest.raises(AssertionError):
+                runner.run([loader], [('train', 1)], 1)
+        else:
+            runner.run([loader], [('train', 1)], 1)