# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase

import torch
from mmengine import ConfigDict
from mmengine.registry import init_default_scope

from mmpretrain.models import AverageClsScoreTTA, ImageClassifier
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample

init_default_scope('mmpretrain')


class TestAverageClsScoreTTA(TestCase):
    """Unit tests for the ``AverageClsScoreTTA`` test-time-augmentation model.

    Every test builds the model through the ``MODELS`` registry from
    ``DEFAULT_ARGS``.
    """

    # Registry config: a TTA wrapper whose ``module`` is an ImageClassifier
    # made of a depth-18 ResNet backbone, a global-average-pooling neck and a
    # 10-class linear head fed with 512 channels.
    DEFAULT_ARGS = {
        'type': 'AverageClsScoreTTA',
        'module': {
            'type': 'ImageClassifier',
            'backbone': {'type': 'ResNet', 'depth': 18},
            'neck': {'type': 'GlobalAveragePooling'},
            'head': {
                'type': 'LinearClsHead',
                'num_classes': 10,
                'in_channels': 512,
                'loss': {'type': 'CrossEntropyLoss'},
            },
        },
    }

    def test_initialize(self):
        """The built model exposes its ``module`` as an ImageClassifier."""
        tta_model: AverageClsScoreTTA = MODELS.build(self.DEFAULT_ARGS)
        wrapped = tta_model.module
        self.assertIsInstance(wrapped, ImageClassifier)

    def test_forward(self):
        """Calling the TTA model directly raises ``NotImplementedError``."""
        dummy_images = torch.rand(1, 3, 224, 224)
        tta_model: AverageClsScoreTTA = MODELS.build(self.DEFAULT_ARGS)

        # The forward of TTA model should not be called.
        expect_error = self.assertRaisesRegex(NotImplementedError,
                                              'will not be called')
        with expect_error:
            tta_model(dummy_images)

    def test_test_step(self):
        """The TTA ``test_step`` score is the mean of the per-view scores.

        Two random images are scored one at a time by the wrapped module and
        then together by the TTA model; the TTA ``pred_score`` must match the
        average of the two individual ``pred_score`` tensors.
        """

        def labelled_sample():
            # A fresh data sample carrying ground-truth label 1.
            return DataSample().set_gt_label(1)

        # Work on a deep copy so the class-level config is left untouched
        # when the data preprocessor settings are attached.
        tta_cfg = ConfigDict(deepcopy(self.DEFAULT_ARGS))
        tta_cfg.module.data_preprocessor = dict(
            mean=[127.5] * 3, std=[127.5] * 3)
        tta_model: AverageClsScoreTTA = MODELS.build(tta_cfg)

        image_shape = (1, 3, 224, 224)
        view_a = torch.randint(0, 256, image_shape)
        view_b = torch.randint(0, 256, image_shape)

        single_a = dict(inputs=view_a, data_samples=[labelled_sample()])
        single_b = dict(inputs=view_b, data_samples=[labelled_sample()])
        # TTA input: one entry per view for both inputs and data samples.
        tta_batch = dict(
            inputs=[view_a, view_b],
            data_samples=[[labelled_sample()] for _ in range(2)])

        classifier = tta_model.module
        score_a = classifier.test_step(single_a)[0].pred_score
        score_b = classifier.test_step(single_b)[0].pred_score
        merged_score = tta_model.test_step(tta_batch)[0].pred_score

        expected_score = (score_a + score_b) / 2
        torch.testing.assert_allclose(merged_score, expected_score)