[Enhance] Added NumClassCheckHook and unit tests (#559)

* Added NumClassCheckHook and CI tests

* Added HOOKS local registry. NumClassCheckHook and unit test files redistribution.

* Extended hook for supporting IterRunner & EpochRunner. Extended unit test.

* Simplification of ClassNumCheckHook. Minor changes.
pull/597/head
Eduardo López 2021-12-08 11:15:05 +01:00 committed by GitHub
parent abd7001bd9
commit d232912391
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 160 additions and 0 deletions

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .class_num_check_hook import ClassNumCheckHook
__all__ = ['ClassNumCheckHook']

View File

@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved
from mmcv.runner import IterBasedRunner
from mmcv.runner.hooks import HOOKS, Hook
from mmcv.utils import is_seq_of
@HOOKS.register_module()
class ClassNumCheckHook(Hook):
def _check_head(self, runner, dataset):
"""Check whether the `num_classes` in head matches the length of
`CLASSES` in `dataset`.
Args:
runner (obj:`EpochBasedRunner`, `IterBasedRunner`): runner object.
dataset (obj: `BaseDataset`): the dataset to check.
"""
model = runner.model
if dataset.CLASSES is None:
runner.logger.warning(
f'Please set `CLASSES` '
f'in the {dataset.__class__.__name__} and'
f'check if it is consistent with the `num_classes` '
f'of head')
else:
assert is_seq_of(dataset.CLASSES, str), \
(f'`CLASSES` in {dataset.__class__.__name__}'
f'should be a tuple of str.')
for name, module in model.named_modules():
if hasattr(module, 'num_classes'):
assert module.num_classes == len(dataset.CLASSES), \
(f'The `num_classes` ({module.num_classes}) in '
f'{module.__class__.__name__} of '
f'{model.__class__.__name__} does not matches '
f'the length of `CLASSES` '
f'{len(dataset.CLASSES)}) in '
f'{dataset.__class__.__name__}')
def before_train_iter(self, runner):
"""Check whether the training dataset is compatible with head.
Args:
runner (obj: `IterBasedRunner`): Iter based Runner.
"""
if not isinstance(runner, IterBasedRunner):
return
self._check_head(runner, runner.data_loader._dataloader.dataset)
def before_val_iter(self, runner):
"""Check whether the eval dataset is compatible with head.
Args:
runner (obj:`IterBasedRunner`): Iter based Runner.
"""
if not isinstance(runner, IterBasedRunner):
return
self._check_head(runner, runner.data_loader._dataloader.dataset)
def before_train_epoch(self, runner):
"""Check whether the training dataset is compatible with head.
Args:
runner (obj:`EpochBasedRunner`): Epoch based Runner.
"""
self._check_head(runner, runner.data_loader.dataset)
def before_val_epoch(self, runner):
"""Check whether the eval dataset is compatible with head.
Args:
runner (obj:`EpochBasedRunner`): Epoch based Runner.
"""
self._check_head(runner, runner.data_loader.dataset)

View File

@ -0,0 +1,83 @@
import logging
import tempfile
from unittest.mock import MagicMock
import mmcv.runner as mmcv_runner
import pytest
import torch
from mmcv.runner import obj_from_dict
from torch.utils.data import DataLoader, Dataset
from mmcls.core.hook import ClassNumCheckHook
from mmcls.models.heads.base_head import BaseHead
class ExampleDataset(Dataset):
def __init__(self, CLASSES):
self.CLASSES = CLASSES
def __getitem__(self, idx):
results = dict(img=torch.tensor([1]), img_metas=dict())
return results
def __len__(self):
return 1
class ExampleHead(BaseHead):
def __init__(self, init_cfg=None):
super(BaseHead, self).__init__(init_cfg)
self.num_classes = 4
def forward_train(self, x, gt_label=None, **kwargs):
pass
class ExampleModel(torch.nn.Module):
def __init__(self):
super(ExampleModel, self).__init__()
self.test_cfg = None
self.conv = torch.nn.Conv2d(3, 3, 3)
self.head = ExampleHead()
def forward(self, img, img_metas, test_mode=False, **kwargs):
return img
def train_step(self, data_batch, optimizer):
loss = self.forward(**data_batch)
return dict(loss=loss)
@pytest.mark.parametrize('runner_type',
['EpochBasedRunner', 'IterBasedRunner'])
@pytest.mark.parametrize(
'CLASSES', [None, ('A', 'B', 'C', 'D', 'E'), ('A', 'B', 'C', 'D')])
def test_num_class_hook(runner_type, CLASSES):
test_dataset = ExampleDataset(CLASSES)
loader = DataLoader(test_dataset, batch_size=1)
model = ExampleModel()
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer = obj_from_dict(optim_cfg, torch.optim,
dict(params=model.parameters()))
with tempfile.TemporaryDirectory() as tmpdir:
num_class_hook = ClassNumCheckHook()
logger_mock = MagicMock(spec=logging.Logger)
runner = getattr(mmcv_runner, runner_type)(
model=model,
optimizer=optimizer,
work_dir=tmpdir,
logger=logger_mock,
max_epochs=1)
runner.register_hook(num_class_hook)
if CLASSES is None:
runner.run([loader], [('train', 1)], 1)
logger_mock.warning.assert_called()
elif len(CLASSES) != 4:
with pytest.raises(AssertionError):
runner.run([loader], [('train', 1)], 1)
else:
runner.run([loader], [('train', 1)], 1)