mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* Add build_runner_with_tta and PrepareTTAHook * rename hook file * support build tta runner with runner type * add unit test * Add build_runner_with_tta to index.rst * minor refine * Add runner test cast * Fix unit test * fix unit test * tmp save * pop None if key does not exist * Fix is_model_wrapper and force register class in test_runner * [Fix] Fix is_model_wrapper * destroy group after ut * register module in testcase * pass through unit test * fix as comment * remove breakpoint * remove mmengine/testing/runner_test_cast.py * minor refine * minor refine * minor refine * set default data preprocessor for model * minor refine * minor refine Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint * Fix unit test * replace with in ImgDataPreprocessor * Fix as comment * add inference tutorial in advanced tutorial * update index.rst * add tta example * refine tta tutorial * Add english tutorial * add note for build_runner_with_tta * Fix as comment * add examples * remove chinese comment * Update docs/en/advanced_tutorials/test_time_augmentation.md Co-authored-by: RangiLyu <lyuchqi@gmail.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: RangiLyu <lyuchqi@gmail.com>
46 lines
1.6 KiB
Python
46 lines
1.6 KiB
Python
from copy import deepcopy
|
|
|
|
from mmengine.hub import get_config
|
|
from mmengine.model import BaseTTAModel
|
|
from mmengine.registry import MODELS
|
|
from mmengine.runner import Runner
|
|
|
|
|
|
@MODELS.register_module()
class ClsTTAModel(BaseTTAModel):
    """Test-time-augmentation wrapper that averages classification scores.

    The wrapped classifier produces one prediction per augmented view of
    each input; :meth:`merge_preds` fuses those per-view predictions into a
    single prediction per input by averaging ``pred_label.score``.
    """

    def merge_preds(self, data_samples_list):
        """Merge the per-view predictions of every sample in a batch.

        Args:
            data_samples_list (list[list]): One inner list per input sample,
                each holding that sample's predicted data samples, one per
                augmented view.

        Returns:
            list: One merged data sample per input sample, in input order.
        """
        return [
            self._merge_single_sample(data_samples)
            for data_samples in data_samples_list
        ]

    def _merge_single_sample(self, data_samples):
        """Average one sample's per-view scores into a single prediction.

        Args:
            data_samples (list): Non-empty list of predicted data samples of
                the same input, one per augmented view. Each must carry
                ``pred_label.score``.

        Returns:
            A new data sample (built from the first view via ``new()``) whose
            predicted score is the mean score over all views and whose
            predicted label is the arg-max of that mean score.
        """
        # ``new()`` carries over the first view's data, including its
        # ``pred_label.label``. Only overwriting the score would leave that
        # stale label in place even when the averaged score favours another
        # class, so the label is recomputed below as well.
        merged_data_sample = data_samples[0].new()
        merged_score = sum(data_sample.pred_label.score
                           for data_sample in data_samples) / len(data_samples)
        merged_data_sample.set_pred_score(merged_score)
        # Assumes ``pred_label.score`` is a 1-D per-class score tensor (as
        # produced by the mmcls classification head) -- TODO confirm for
        # other heads.
        merged_data_sample.set_pred_label(merged_score.argmax())
        return merged_data_sample
|
|
|
|
|
|
if __name__ == '__main__':
    # Start from the mmcls ResNet-50 / CIFAR-10 config and wrap its
    # classifier in the ``ClsTTAModel`` registered above.
    cfg = get_config('mmcls::resnet/resnet50_8xb16_cifar10.py')
    cfg.work_dir = 'work_dirs/resnet50_8xb16_cifar10'
    cfg.model = dict(type='ClsTTAModel', module=cfg.model)

    # Independent copy of the pipeline's final transform, so it can be
    # reused inside ``TestTimeAug`` after its original slot is replaced.
    pack_transform = deepcopy(cfg.test_dataloader.dataset.pipeline)[-1]

    # Two views per image: one always flipped (prob=1.) and one never
    # flipped (prob=0.); ``TestTimeAug`` pairs each with ``pack_transform``.
    flip_views = [
        dict(type='RandomFlip', prob=1.),
        dict(type='RandomFlip', prob=0.),
    ]
    cfg.test_dataloader.dataset.pipeline[-1] = dict(
        type='TestTimeAug', transforms=[flip_views, [pack_transform]])

    # Pretrained ResNet-50 CIFAR-10 checkpoint to run the test loop with.
    cfg.load_from = ('https://download.openmmlab.com/mmclassification/v0'
                     '/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth')

    Runner.from_cfg(cfg).test()
|