85 lines
2.5 KiB
Python
85 lines
2.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import logging
|
|
import tempfile
|
|
from unittest.mock import MagicMock
|
|
|
|
import mmcv.runner as mmcv_runner
|
|
import pytest
|
|
import torch
|
|
from mmcv.runner import obj_from_dict
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
from mmcls.core.hook import ClassNumCheckHook
|
|
from mmcls.models.heads.base_head import BaseHead
|
|
|
|
|
|
class ExampleDataset(Dataset):
|
|
|
|
def __init__(self, CLASSES):
|
|
self.CLASSES = CLASSES
|
|
|
|
def __getitem__(self, idx):
|
|
results = dict(img=torch.tensor([1]), img_metas=dict())
|
|
return results
|
|
|
|
def __len__(self):
|
|
return 1
|
|
|
|
|
|
class ExampleHead(BaseHead):
|
|
|
|
def __init__(self, init_cfg=None):
|
|
super(BaseHead, self).__init__(init_cfg)
|
|
self.num_classes = 4
|
|
|
|
def forward_train(self, x, gt_label=None, **kwargs):
|
|
pass
|
|
|
|
|
|
class ExampleModel(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
super(ExampleModel, self).__init__()
|
|
self.test_cfg = None
|
|
self.conv = torch.nn.Conv2d(3, 3, 3)
|
|
self.head = ExampleHead()
|
|
|
|
def forward(self, img, img_metas, test_mode=False, **kwargs):
|
|
return img
|
|
|
|
def train_step(self, data_batch, optimizer):
|
|
loss = self.forward(**data_batch)
|
|
return dict(loss=loss)
|
|
|
|
|
|
@pytest.mark.parametrize('runner_type',
|
|
['EpochBasedRunner', 'IterBasedRunner'])
|
|
@pytest.mark.parametrize(
|
|
'CLASSES', [None, ('A', 'B', 'C', 'D', 'E'), ('A', 'B', 'C', 'D')])
|
|
def test_num_class_hook(runner_type, CLASSES):
|
|
test_dataset = ExampleDataset(CLASSES)
|
|
loader = DataLoader(test_dataset, batch_size=1)
|
|
model = ExampleModel()
|
|
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
|
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
|
dict(params=model.parameters()))
|
|
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
num_class_hook = ClassNumCheckHook()
|
|
logger_mock = MagicMock(spec=logging.Logger)
|
|
runner = getattr(mmcv_runner, runner_type)(
|
|
model=model,
|
|
optimizer=optimizer,
|
|
work_dir=tmpdir,
|
|
logger=logger_mock,
|
|
max_epochs=1)
|
|
runner.register_hook(num_class_hook)
|
|
if CLASSES is None:
|
|
runner.run([loader], [('train', 1)], 1)
|
|
logger_mock.warning.assert_called()
|
|
elif len(CLASSES) != 4:
|
|
with pytest.raises(AssertionError):
|
|
runner.run([loader], [('train', 1)], 1)
|
|
else:
|
|
runner.run([loader], [('train', 1)], 1)
|