Ma Zerun 97c4ae8805
[Improve] Update registries of mmcls. (#1306)
* [Improve] Update registries of mmcls.

* Update according to comments
2023-01-11 15:20:51 +08:00

65 lines
2.2 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase
import torch
from mmengine import ConfigDict
from mmcls.models import AverageClsScoreTTA, ImageClassifier
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
class TestAverageClsScoreTTA(TestCase):
    """Unit tests for the ``AverageClsScoreTTA`` test-time-augmentation model."""

    # TTA wrapper around a ResNet-18 image classifier with a 10-class
    # linear head.
    DEFAULT_ARGS = dict(
        type='AverageClsScoreTTA',
        module=dict(
            type='ImageClassifier',
            backbone=dict(type='ResNet', depth=18),
            neck=dict(type='GlobalAveragePooling'),
            head=dict(
                type='LinearClsHead',
                num_classes=10,
                in_channels=512,
                loss=dict(type='CrossEntropyLoss'))))

    def test_initialize(self):
        """Building the TTA model should also build the wrapped classifier."""
        model: AverageClsScoreTTA = MODELS.build(self.DEFAULT_ARGS)
        self.assertIsInstance(model.module, ImageClassifier)

    def test_forward(self):
        """Calling the TTA model directly is unsupported and must raise."""
        inputs = torch.rand(1, 3, 224, 224)
        model: AverageClsScoreTTA = MODELS.build(self.DEFAULT_ARGS)
        # The forward of TTA model should not be called.
        with self.assertRaisesRegex(NotImplementedError, 'will not be called'):
            model(inputs)

    def test_test_step(self):
        """The TTA score should equal the mean of the per-view scores."""
        # Deep-copy so that attaching a data preprocessor below cannot leak
        # into the class-level DEFAULT_ARGS shared with the other tests.
        cfg = ConfigDict(deepcopy(self.DEFAULT_ARGS))
        cfg.module.data_preprocessor = dict(
            mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
        model: AverageClsScoreTTA = MODELS.build(cfg)

        img1 = torch.randint(0, 256, (1, 3, 224, 224))
        img2 = torch.randint(0, 256, (1, 3, 224, 224))
        data1 = {
            'inputs': img1,
            'data_samples': [ClsDataSample().set_gt_label(1)]
        }
        data2 = {
            'inputs': img2,
            'data_samples': [ClsDataSample().set_gt_label(1)]
        }
        # TTA input: one entry per augmented view; each entry is laid out
        # exactly like the single-view batches above.
        data_tta = {
            'inputs': [img1, img2],
            'data_samples': [[ClsDataSample().set_gt_label(1)],
                             [ClsDataSample().set_gt_label(1)]]
        }

        score1 = model.module.test_step(data1)[0].pred_label.score
        score2 = model.module.test_step(data2)[0].pred_label.score
        score_tta = model.test_step(data_tta)[0].pred_label.score

        # ``torch.testing.assert_allclose`` is deprecated; use
        # ``assert_close`` and fall back only on PyTorch releases that
        # predate it.
        assert_close = (getattr(torch.testing, 'assert_close', None)
                        or torch.testing.assert_allclose)
        assert_close(score_tta, (score1 + score2) / 2)