from copy import deepcopy

from mmengine.hub import get_config
from mmengine.model import BaseTTAModel
from mmengine.registry import MODELS
from mmengine.runner import Runner


@MODELS.register_module()
class ClsTTAModel(BaseTTAModel):
    """Test-time-augmentation wrapper that averages prediction scores.

    Registered in ``MODELS`` so it can be built from a config dict with
    ``type='ClsTTAModel'``.
    """

    def merge_preds(self, data_samples_list):
        """Merge each group of predictions into a single data sample.

        Args:
            data_samples_list: Iterable of groups of data samples. Each
                group is merged on its own (presumably one group holds the
                predictions for every augmented view of one input --
                confirm against ``BaseTTAModel``).

        Returns:
            list: One merged data sample per group, in input order.
        """
        return [
            self._merge_single_sample(view_samples)
            for view_samples in data_samples_list
        ]

    def _merge_single_sample(self, data_samples):
        """Average ``pred_label.score`` over ``data_samples``.

        Args:
            data_samples: Non-empty sequence of data samples, each exposing
                ``pred_label.score``.

        Returns:
            A new data sample created with ``data_samples[0].new()`` whose
            prediction score is set to the mean score of ``data_samples``.
        """
        # Start from a fresh sample derived from the first one so that none
        # of the input samples is modified.
        merged = data_samples[0].new()
        score_total = sum(sample.pred_label.score for sample in data_samples)
        mean_score = score_total / len(data_samples)
        merged.set_pred_score(mean_score)
        return merged


if __name__ == '__main__':
    cfg = get_config('mmcls::resnet/resnet50_8xb16_cifar10.py')
    cfg.work_dir = 'work_dirs/resnet50_8xb16_cifar10'
    # Wrap the configured model inside the TTA model defined above.
    cfg.model = dict(type='ClsTTAModel', module=cfg.model)

    # Work on a copy so the transform reused below is independent of the
    # pipeline entry that gets overwritten afterwards.
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    last_transform = test_pipeline[-1]

    # Two flip variants: one with prob=1. and one with prob=0.
    flip_variants = [
        dict(type='RandomFlip', prob=1.),
        dict(type='RandomFlip', prob=0.),
    ]
    flip_tta = dict(
        type='TestTimeAug',
        transforms=[
            flip_variants,
            [last_transform],
        ],
    )

    # Swap the final transform of the test pipeline for `TestTimeAug`,
    # which itself ends with that original final transform.
    cfg.test_dataloader.dataset.pipeline[-1] = flip_tta
    cfg.load_from = (
        'https://download.openmmlab.com/mmclassification/v0'
        '/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth'
    )

    runner = Runner.from_cfg(cfg)
    runner.test()