# mmengine/examples/test_time_augmentation.py
#
# 46 lines
# 1.6 KiB
# Python

from copy import deepcopy
from mmengine.hub import get_config
from mmengine.model import BaseTTAModel
from mmengine.registry import MODELS
from mmengine.runner import Runner
@MODELS.register_module()
class ClsTTAModel(BaseTTAModel):
    """TTA wrapper that averages classification scores across views.

    ``merge_preds`` is the hook ``BaseTTAModel`` uses to combine the
    predictions made for the augmented views of each test image (see the
    mmengine ``BaseTTAModel`` docs). Here the per-view ``pred_label.score``
    tensors are averaged, and the predicted label is recomputed from that
    averaged score.
    """

    def merge_preds(self, data_samples_list):
        """Merge the augmented predictions of every image in a batch.

        Args:
            data_samples_list: One entry per image; each entry is the
                sequence of data samples predicted for that image's
                augmented views.

        Returns:
            list: One merged data sample per image, in input order.
        """
        return [
            self._merge_single_sample(data_samples)
            for data_samples in data_samples_list
        ]

    def _merge_single_sample(self, data_samples):
        """Average the prediction scores of all views of a single image.

        Args:
            data_samples: Non-empty sequence of per-view data samples, each
                exposing ``pred_label.score``.

        Returns:
            A new data sample (created via ``data_samples[0].new()``) whose
            prediction score is the mean of the per-view scores and whose
            predicted label is the argmax of that mean.

        Raises:
            ValueError: If ``data_samples`` is empty.
        """
        if not data_samples:
            # Fail loudly instead of an opaque IndexError on ``[0]`` below.
            raise ValueError('Cannot merge an empty list of TTA predictions')
        merged_data_sample = data_samples[0].new()
        merged_score = sum(data_sample.pred_label.score
                           for data_sample in data_samples) / len(data_samples)
        merged_data_sample.set_pred_score(merged_score)
        # ``new()`` clones the first view's fields, including its predicted
        # label; without this call the merged sample would report the first
        # view's label alongside the averaged score.
        merged_data_sample.set_pred_label(merged_score.argmax())
        return merged_data_sample
if __name__ == '__main__':
    # Base recipe: ResNet-50 on CIFAR-10 from the mmcls config zoo.
    cfg = get_config('mmcls::resnet/resnet50_8xb16_cifar10.py')
    cfg.work_dir = 'work_dirs/resnet50_8xb16_cifar10'

    # Wrap the model from the original config in the `ClsTTAModel`
    # defined in this file.
    cfg.model = dict(type='ClsTTAModel', module=cfg.model)

    # Two views per test image: one always flipped (prob=1.) and one never
    # flipped (prob=0.). Both views then go through a private copy of the
    # test pipeline's original final transform.
    final_transform = deepcopy(cfg.test_dataloader.dataset.pipeline[-1])
    flip_views = [dict(type='RandomFlip', prob=p) for p in (1., 0.)]

    # Swap the pipeline's last transform for the `TestTimeAug` wrapper.
    cfg.test_dataloader.dataset.pipeline[-1] = dict(
        type='TestTimeAug', transforms=[flip_views, [final_transform]])

    # Pretrained checkpoint for this recipe.
    cfg.load_from = (
        'https://download.openmmlab.com/mmclassification/v0'
        '/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth')

    Runner.from_cfg(cfg).test()