68 lines
2.3 KiB
Python
68 lines
2.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from copy import deepcopy
|
|
from unittest import TestCase
|
|
|
|
import torch
|
|
from mmengine import ConfigDict
|
|
from mmengine.registry import init_default_scope
|
|
|
|
from mmpretrain.models import AverageClsScoreTTA, ImageClassifier
|
|
from mmpretrain.registry import MODELS
|
|
from mmpretrain.structures import DataSample
|
|
|
|
# Set 'mmpretrain' as the default registry scope at import time so that the
# bare type names used in the configs below (e.g. 'ResNet', 'LinearClsHead')
# are looked up in mmpretrain's registries when passed to MODELS.build.
init_default_scope('mmpretrain')
|
|
|
|
|
|
class TestAverageClsScoreTTA(TestCase):
    """Unit tests for the ``AverageClsScoreTTA`` test-time-augmentation wrapper.

    ``DEFAULT_ARGS`` is a shared class-level config dict. Every test builds
    from a deep copy of it, so nothing done while building one model can
    leak into the config seen by another test.
    """

    # A ResNet-18 -> global-average-pooling -> 10-class linear-head image
    # classifier, wrapped by AverageClsScoreTTA.
    DEFAULT_ARGS = dict(
        type='AverageClsScoreTTA',
        module=dict(
            type='ImageClassifier',
            backbone=dict(type='ResNet', depth=18),
            neck=dict(type='GlobalAveragePooling'),
            head=dict(
                type='LinearClsHead',
                num_classes=10,
                in_channels=512,
                loss=dict(type='CrossEntropyLoss'))))

    def test_initialize(self):
        """Building the TTA config wraps an ``ImageClassifier`` as ``module``."""
        model: AverageClsScoreTTA = MODELS.build(deepcopy(self.DEFAULT_ARGS))
        self.assertIsInstance(model.module, ImageClassifier)

    def test_forward(self):
        """Calling the TTA wrapper directly raises ``NotImplementedError``."""
        inputs = torch.rand(1, 3, 224, 224)
        model: AverageClsScoreTTA = MODELS.build(deepcopy(self.DEFAULT_ARGS))

        # The forward of TTA model should not be called.
        with self.assertRaisesRegex(NotImplementedError, 'will not be called'):
            model(inputs)

    def test_test_step(self):
        """TTA ``pred_score`` equals the mean of the per-view scores.

        Two random images are each run through the wrapped classifier on
        their own, then through the TTA wrapper as two views of one sample.
        The TTA score must be the average of the two individual scores.
        """
        cfg = ConfigDict(deepcopy(self.DEFAULT_ARGS))
        # Normalize pixel values in [0, 255] to [-1, 1], presumably applied
        # by the data preprocessor as (x - mean) / std.
        cfg.module.data_preprocessor = dict(
            mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
        model: AverageClsScoreTTA = MODELS.build(cfg)

        img1 = torch.randint(0, 256, (1, 3, 224, 224))
        img2 = torch.randint(0, 256, (1, 3, 224, 224))
        data1 = {
            'inputs': img1,
            'data_samples': [DataSample().set_gt_label(1)]
        }
        data2 = {
            'inputs': img2,
            'data_samples': [DataSample().set_gt_label(1)]
        }
        # TTA input: one list entry per augmented view, each entry holding
        # the inputs / data samples for that view.
        data_tta = {
            'inputs': [img1, img2],
            'data_samples': [[DataSample().set_gt_label(1)],
                             [DataSample().set_gt_label(1)]]
        }

        score1 = model.module.test_step(data1)[0].pred_score
        score2 = model.module.test_step(data2)[0].pred_score
        score_tta = model.test_step(data_tta)[0].pred_score

        # ``torch.testing.assert_allclose`` is deprecated. ``assert_close`` is
        # its replacement and additionally checks that dtype and device match.
        torch.testing.assert_close(score_tta, (score1 + score2) / 2)
|