mmsegmentation/tools/model_ensemble.py

122 lines
3.7 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import mmcv
import numpy as np
import torch
from mmcv.parallel import MMDataParallel
from mmcv.parallel.scatter_gather import scatter_kwargs
from mmcv.runner import load_checkpoint, wrap_fp16_model
from PIL import Image
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor
@torch.no_grad()
def main(args):
    """Run every configured segmentor over the test set and save the ensemble.

    The test dataset is built from the first config only. For each sample,
    the logits returned by all models are summed, the argmax over axis 1 of
    the sum is taken as the prediction, and the result is written as a
    ``uint8`` image into ``args.out``.

    Args:
        args (argparse.Namespace): Parsed command line options. The fields
            read here are ``config`` and ``checkpoint`` (parallel lists of
            paths), ``gpus`` (list of device ids), ``aug_test`` (bool) and
            ``out`` (output directory).
    """
    gpu_ids = args.gpus
    configs = args.config
    ckpts = args.checkpoint

    # Only the first config decides which dataset / test pipeline is used.
    cfg = mmcv.Config.fromfile(configs[0])

    # NOTE(review): pipeline[1] is presumably the multi-scale/flip
    # augmentation step -- confirm against the configs in use.
    tta_step = cfg.data.test.pipeline[1]
    if args.aug_test:
        tta_step.img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
        tta_step.flip = True
    else:
        tta_step.img_ratios = [1.0]
        tta_step.flip = False

    torch.backends.cudnn.benchmark = True

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=4,
        dist=False,
        shuffle=False,
    )

    # The output directory does not depend on the model being loaded, so it
    # is prepared once here instead of on every pass of the loading loop.
    out_dir = args.out
    mmcv.mkdir_or_exist(out_dir)

    models = []
    for idx, (config, ckpt) in enumerate(zip(configs, ckpts)):
        cfg = mmcv.Config.fromfile(config)
        cfg.model.pretrained = None
        cfg.data.test.test_mode = True

        model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
        if cfg.get('fp16', None):
            wrap_fp16_model(model)
        load_checkpoint(model, ckpt, map_location='cpu')
        torch.cuda.empty_cache()
        # Assign the models round-robin to the requested GPU ids.
        model = MMDataParallel(model, device_ids=[gpu_ids[idx % len(gpu_ids)]])
        model.eval()
        models.append(model)

    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    # The batch sampler is iterated next to the loader to recover the dataset
    # index of each batch; this relies on both iterations producing the same
    # order (the loader is built with shuffle=False above).
    for batch_indices, data in zip(data_loader.batch_sampler, data_loader):
        result = []
        for model in models:
            x, _ = scatter_kwargs(
                inputs=data, kwargs=None, target_gpus=model.device_ids)
            if args.aug_test:
                logits = model.module.aug_test_logits(**x[0])
            else:
                logits = model.module.simple_test_logits(**x[0])
            result.append(logits)

        # Ensemble by summing the logits of all models.
        result_logits = sum(result)

        pred = result_logits.argmax(axis=1).squeeze()
        img_info = dataset.img_infos[batch_indices[0]]
        # basename() copes with every separator the platform accepts, so a
        # '/'-separated seg_map path is also reduced to its file name on
        # Windows instead of leaking sub-directories into the output path.
        seg_map_name = os.path.basename(img_info['ann']['seg_map'])
        file_name = os.path.join(out_dir, seg_map_name)
        Image.fromarray(pred.astype(np.uint8)).save(file_name)
        prog_bar.update()
def parse_args():
    """Parse and validate the command line options of the ensemble script.

    Returns:
        argparse.Namespace: Options with the fields ``config``,
            ``checkpoint``, ``aug_test``, ``out`` and ``gpus``.

    Raises:
        SystemExit: Raised through argparse, after printing a usage message,
            when a required option is missing, when the number of configs
            differs from the number of checkpoints, or when ``--out`` is
            empty.
    """
    parser = argparse.ArgumentParser(
        description='Model Ensemble with logits result')
    parser.add_argument(
        '--config',
        type=str,
        nargs='+',
        required=True,
        help='ensemble config files path')
    parser.add_argument(
        '--checkpoint',
        type=str,
        nargs='+',
        required=True,
        help='ensemble checkpoint files path')
    parser.add_argument(
        '--aug-test',
        action='store_true',
        help='control ensemble aug-result or single-result (default)')
    parser.add_argument(
        '--out', type=str, default='results', help='the dir to save result')
    parser.add_argument(
        '--gpus', type=int, nargs='+', default=[0], help='id of gpu to use')
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are stripped
    # when Python runs with -O, which would silently skip these checks.
    if len(args.config) != len(args.checkpoint):
        parser.error(
            'len(config) must equal len(checkpoint), '
            f'but len(config) = {len(args.config)} and '
            f'len(checkpoint) = {len(args.checkpoint)}')
    if not args.out:
        parser.error("ensemble result out-dir can't be None")
    return args
if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the ensemble.
    main(parse_args())