# mmpretrain/tests/test_models/test_tta.py
#
# 68 lines
# 2.3 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase
import torch
from mmengine import ConfigDict
from mmengine.registry import init_default_scope
from mmpretrain.models import AverageClsScoreTTA, ImageClassifier
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
# Make 'mmpretrain' the default registry scope at import time, so that the
# string ``type`` names in the configs below (e.g. 'AverageClsScoreTTA',
# 'ResNet') are looked up in mmpretrain's registries by ``MODELS.build``.
init_default_scope('mmpretrain')
class TestAverageClsScoreTTA(TestCase):
    """Unit tests for :class:`AverageClsScoreTTA`.

    The TTA wrapper is expected to run its wrapped classifier once per
    augmented view and average the per-view ``pred_score`` tensors.
    """

    # Base config: a TTA wrapper around a ResNet-18 image classifier with a
    # 10-class linear head. Treated as read-only; tests build from copies.
    DEFAULT_ARGS = dict(
        type='AverageClsScoreTTA',
        module=dict(
            type='ImageClassifier',
            backbone=dict(type='ResNet', depth=18),
            neck=dict(type='GlobalAveragePooling'),
            head=dict(
                type='LinearClsHead',
                num_classes=10,
                in_channels=512,
                loss=dict(type='CrossEntropyLoss'))))

    def _build_model(self, cfg=None):
        """Build a TTA model from a deep copy of ``cfg``.

        Args:
            cfg (dict, optional): Model config. Defaults to
                :attr:`DEFAULT_ARGS`.

        Returns:
            AverageClsScoreTTA: The built model.
        """
        if cfg is None:
            cfg = self.DEFAULT_ARGS
        # Deep-copy so that, should a builder mutate the nested dicts in
        # place, the change cannot leak into the shared class-level config.
        return MODELS.build(deepcopy(cfg))

    def test_initialize(self):
        """The wrapper and its inner classifier are built from the config."""
        model: AverageClsScoreTTA = self._build_model()
        self.assertIsInstance(model, AverageClsScoreTTA)
        self.assertIsInstance(model.module, ImageClassifier)

    def test_forward(self):
        """Calling the TTA wrapper directly is not supported."""
        inputs = torch.rand(1, 3, 224, 224)
        model: AverageClsScoreTTA = self._build_model()
        # The forward of TTA model should not be called.
        with self.assertRaisesRegex(NotImplementedError, 'will not be called'):
            model(inputs)

    def test_test_step(self):
        """``test_step`` on several views returns the mean per-view score."""
        cfg = ConfigDict(deepcopy(self.DEFAULT_ARGS))
        # Normalize raw pixel values with (x - 127.5) / 127.5.
        cfg.module.data_preprocessor = dict(
            mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
        model: AverageClsScoreTTA = MODELS.build(cfg)

        # Two independent "views", each a single-image batch of integer
        # pixel values in [0, 255].
        img1 = torch.randint(0, 256, (1, 3, 224, 224))
        img2 = torch.randint(0, 256, (1, 3, 224, 224))
        data1 = {
            'inputs': img1,
            'data_samples': [DataSample().set_gt_label(1)]
        }
        data2 = {
            'inputs': img2,
            'data_samples': [DataSample().set_gt_label(1)]
        }
        # TTA input layout: each key maps to a list with one entry per view.
        data_tta = {
            'inputs': [img1, img2],
            'data_samples': [[DataSample().set_gt_label(1)],
                             [DataSample().set_gt_label(1)]]
        }

        # Pure inference: no autograd graph is needed for the comparison.
        with torch.no_grad():
            score1 = model.module.test_step(data1)[0].pred_score
            score2 = model.module.test_step(data2)[0].pred_score
            score_tta = model.test_step(data_tta)[0].pred_score
        # ``torch.testing.assert_allclose`` is deprecated; ``assert_close``
        # is its documented replacement.
        torch.testing.assert_close(score_tta, (score1 + score2) / 2)