46 lines
1.6 KiB
Python
46 lines
1.6 KiB
Python
from copy import deepcopy
|
|
|
|
from mmengine.hub import get_config
|
|
from mmengine.model import BaseTTAModel
|
|
from mmengine.registry import MODELS
|
|
from mmengine.runner import Runner
|
|
|
|
|
|
@MODELS.register_module()
class ClsTTAModel(BaseTTAModel):
    """TTA model that merges predictions by averaging their scores.

    The class is registered in ``MODELS`` so a config can refer to it with
    ``type='ClsTTAModel'``.
    """

    def merge_preds(self, data_samples_list):
        """Merge each group of data samples into a single data sample.

        Args:
            data_samples_list: Iterable of groups of data samples.
                Presumably each group holds the predictions obtained for
                the augmented versions of one input -- TODO confirm
                against ``BaseTTAModel``.

        Returns:
            list: One merged data sample per group, in input order.
        """
        return [
            self._merge_single_sample(group) for group in data_samples_list
        ]

    def _merge_single_sample(self, data_samples):
        """Average ``pred_label.score`` over one group of data samples.

        Args:
            data_samples: Non-empty, indexable and sized group of data
                samples; each one must expose ``pred_label.score``.

        Returns:
            The object returned by ``data_samples[0].new()``, after the
            mean score has been stored on it via ``set_pred_score``.
        """
        # Create the container first, exactly like the original flow, so an
        # empty group fails on the index access before any score is read.
        merged = data_samples[0].new()
        # ``sum`` starts from the int 0, so the score type must support
        # ``0 + score`` (assumed to be tensor-like -- TODO confirm).
        score_total = sum(sample.pred_label.score for sample in data_samples)
        mean_score = score_total / len(data_samples)
        merged.set_pred_score(mean_score)
        return merged
|
|
|
|
|
|
if __name__ == '__main__':
    # Load the ResNet-50 / CIFAR-10 config from the `mmcls` package scope.
    cfg = get_config('mmcls::resnet/resnet50_8xb16_cifar10.py')
    cfg.work_dir = 'work_dirs/resnet50_8xb16_cifar10'

    # Wrap the original model config inside the `ClsTTAModel` wrapper.
    cfg.model = {'type': 'ClsTTAModel', 'module': cfg.model}

    # Work on a copy so the transform reused inside `TestTimeAug` is
    # independent of the pipeline entry that gets overwritten below.
    test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline)
    final_transform = test_pipeline[-1]

    # Two alternatives for the flip step (presumably "always flip" and
    # "never flip" -- confirm against the `RandomFlip` transform), followed
    # by the original last transform of the test pipeline.
    flip_variants = [
        {'type': 'RandomFlip', 'prob': 1.},
        {'type': 'RandomFlip', 'prob': 0.},
    ]
    flip_tta = {
        'type': 'TestTimeAug',
        'transforms': [
            flip_variants,
            [final_transform],
        ],
    }

    # Replace the last transform with `TestTimeAug`
    cfg.test_dataloader.dataset.pipeline[-1] = flip_tta

    # Adjacent literals are concatenated into the single checkpoint URL.
    cfg.load_from = (
        'https://download.openmmlab.com/mmclassification/v0'
        '/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth'
    )

    # Build the runner from the modified config and run the test loop.
    runner = Runner.from_cfg(cfg)
    runner.test()
|