mmengine/examples/test_time_augmentation.py
Mashiro a9b6753fbe
Make TTAModel compatible with FSDP (#611)
* Add build_runner_with_tta and PrepareTTAHook

* rename hook file

* support build tta runner with runner type

* add unit test

* Add build_runner_with_tta to index.rst

* minor refine

* Add runner test case

* Fix unit test

* fix unit test

* tmp save

* pop None if key does not exist

* Fix is_model_wrapper and force register class in test_runner

* [Fix] Fix is_model_wrapper

* destroy group after ut

* register module in testcase

* pass through unit test

* fix as comment

* remove breakpoint

* remove mmengine/testing/runner_test_cast.py

* minor refine

* minor refine

* minor refine

* set default data preprocessor for model

* minor refine

* minor refine

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix lint

* Fix unit test

* replace  with  in ImgDataPreprocessor

* Fix as comment

* add inference tutorial in advanced tutorial

* update index.rst

* add tta example

* refine tta tutorial

* Add english tutorial

* add note for build_runner_with_tta

* Fix as comment

* add examples

* remove chinese comment

* Update docs/en/advanced_tutorials/test_time_augmentation.md

Co-authored-by: RangiLyu <lyuchqi@gmail.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: RangiLyu <lyuchqi@gmail.com>
2022-12-27 21:39:43 +08:00

46 lines
1.6 KiB
Python

from copy import deepcopy
from mmengine.hub import get_config
from mmengine.model import BaseTTAModel
from mmengine.registry import MODELS
from mmengine.runner import Runner
@MODELS.register_module()
class ClsTTAModel(BaseTTAModel):
    """Classification TTA model that averages per-view prediction scores.

    Implements ``merge_preds`` for ``BaseTTAModel``: the predictions made on
    the augmented views of a sample are fused into a single data sample
    whose score is the mean of the per-view scores and whose label is the
    arg-max of that mean.
    """

    def merge_preds(self, data_samples_list):
        """Merge the augmented-view predictions of every sample in a batch.

        Args:
            data_samples_list (list[list]): One inner list per sample
                (presumably one entry per augmented view — confirm against
                ``BaseTTAModel``), each entry exposing ``pred_label.score``.

        Returns:
            list: One merged data sample per inner list, in input order.
        """
        return [
            self._merge_single_sample(data_samples)
            for data_samples in data_samples_list
        ]

    def _merge_single_sample(self, data_samples):
        """Average the predictions of all augmented views of one sample.

        Args:
            data_samples (list): Non-empty list of per-view data samples.
                Each must provide ``pred_label.score`` (assumed to be a 1-D
                score tensor over classes — TODO confirm for the wrapped
                classifier).

        Returns:
            A new data sample whose predicted score is the element-wise mean
            of the per-view scores and whose predicted label is the arg-max
            of that mean.
        """
        # ``new()`` without arguments copies the first view's metainfo and
        # data, including its predicted label.
        merged_data_sample = data_samples[0].new()
        merged_score = sum(data_sample.pred_label.score
                           for data_sample in data_samples) / len(data_samples)
        merged_data_sample.set_pred_score(merged_score)
        # Recompute the label from the averaged score. Otherwise the merged
        # sample keeps the first view's label, which can disagree with the
        # class ranked highest by the averaged score. ``keepdim=True`` keeps
        # the label 1-D, matching the score's layout.
        merged_data_sample.set_pred_label(
            merged_score.argmax(dim=-1, keepdim=True))
        return merged_data_sample
if __name__ == '__main__':
    # Base config: ResNet-50 on CIFAR-10 from the mmcls config scope.
    cfg = get_config('mmcls::resnet/resnet50_8xb16_cifar10.py')
    cfg.work_dir = 'work_dirs/resnet50_8xb16_cifar10'

    # Wrap the original classifier config inside the TTA model class
    # defined in this file.
    base_model_cfg = cfg.model
    cfg.model = dict(type='ClsTTAModel', module=base_model_cfg)

    # Snapshot the test pipeline before editing it, so its final transform
    # can be reused as the last stage of the TTA transform list.
    original_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    last_transform = original_pipeline[-1]

    # Flip branch: prob=1. always flips, prob=0. never flips.
    flip_views = [
        dict(type='RandomFlip', prob=1.),
        dict(type='RandomFlip', prob=0.)
    ]
    flip_tta = dict(
        type='TestTimeAug', transforms=[flip_views, [last_transform]])

    # Replace the last transform with `TestTimeAug`
    cfg.test_dataloader.dataset.pipeline[-1] = flip_tta

    checkpoint_url = (
        'https://download.openmmlab.com/mmclassification/v0'
        '/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth')
    cfg.load_from = checkpoint_url

    runner = Runner.from_cfg(cfg)
    runner.test()