122 lines
3.7 KiB
Python
122 lines
3.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
import os
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
import torch
|
|
from mmcv.parallel import MMDataParallel
|
|
from mmcv.parallel.scatter_gather import scatter_kwargs
|
|
from mmcv.runner import load_checkpoint, wrap_fp16_model
|
|
from PIL import Image
|
|
|
|
from mmseg.datasets import build_dataloader, build_dataset
|
|
from mmseg.models import build_segmentor
|
|
|
|
|
|
@torch.no_grad()
def main(args):
    """Ensemble several segmentors by summing their logits per test sample.

    The test dataset and dataloader are built once, from the first config.
    Every ``(config, checkpoint)`` pair is then built into a segmentor,
    wrapped in ``MMDataParallel`` and pinned to one of the requested GPUs
    (round-robin over ``args.gpus``). For each sample, the logits returned by
    all models are summed, the arg-max over axis 1 is taken, and the result
    is saved as a ``uint8`` image inside ``args.out``.

    Args:
        args (argparse.Namespace): Command-line options. The attributes read
            here are ``config``, ``checkpoint``, ``gpus``, ``aug_test`` and
            ``out``.
    """
    gpu_ids = args.gpus
    configs = args.config
    ckpts = args.checkpoint

    # The dataset/pipeline settings are taken from the first config only.
    cfg = mmcv.Config.fromfile(configs[0])

    # NOTE(review): pipeline[1] is assumed to be the multi-scale/flip
    # test-time-augmentation step of the test pipeline -- confirm per config.
    tta_step = cfg.data.test.pipeline[1]
    if args.aug_test:
        tta_step.img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
        tta_step.flip = True
    else:
        tta_step.img_ratios = [1.0]
        tta_step.flip = False

    torch.backends.cudnn.benchmark = True

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=4,
        dist=False,
        shuffle=False,
    )

    # The output directory is the same for every model, so create it once
    # instead of once per loaded checkpoint.
    out_dir = args.out
    mmcv.mkdir_or_exist(out_dir)

    models = []
    for idx, (config, ckpt) in enumerate(zip(configs, ckpts)):
        cfg = mmcv.Config.fromfile(config)
        cfg.model.pretrained = None
        # NOTE(review): the shared dataset above has already been built from
        # the first config, so this flag does not reach it -- confirm whether
        # it should be set before ``build_dataset`` instead.
        cfg.data.test.test_mode = True

        model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
        if cfg.get('fp16', None):
            wrap_fp16_model(model)
        load_checkpoint(model, ckpt, map_location='cpu')
        torch.cuda.empty_cache()

        # Spread the models over the requested GPUs in round-robin order.
        device_id = gpu_ids[idx % len(gpu_ids)]
        model = MMDataParallel(model, device_ids=[device_id])
        model.eval()
        models.append(model)

    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    # The batch sampler is iterated in lockstep with the loader to recover
    # the dataset index of each sample (presumably this relies on
    # ``shuffle=False`` giving both iterations the same order).
    loader_indices = data_loader.batch_sampler
    for batch_indices, data in zip(loader_indices, data_loader):
        per_model_logits = []
        for model in models:
            # Move the batch onto the GPU that hosts this particular model.
            x, _ = scatter_kwargs(
                inputs=data, kwargs=None, target_gpus=model.device_ids)
            if args.aug_test:
                logits = model.module.aug_test_logits(**x[0])
            else:
                logits = model.module.simple_test_logits(**x[0])
            per_model_logits.append(logits)

        # Element-wise sum of the logits of all models.
        result_logits = sum(per_model_logits)

        # assumes the logits are numpy arrays with the class scores on
        # axis 1 (``astype`` is called on the arg-max below) -- TODO confirm
        pred = result_logits.argmax(axis=1).squeeze()
        img_info = dataset.img_infos[batch_indices[0]]
        # basename() strips the directory part for every separator the
        # platform accepts, not only ``os.path.sep``.
        file_name = os.path.join(
            out_dir, os.path.basename(img_info['ann']['seg_map']))
        Image.fromarray(pred.astype(np.uint8)).save(file_name)
        prog_bar.update()
|
|
|
|
|
|
def parse_args():
    """Parse and validate the command-line options of the ensemble script.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``config``
            and ``checkpoint`` (lists of str with the same length),
            ``aug_test`` (bool), ``out`` (str, default ``'results'``) and
            ``gpus`` (list of int, default ``[0]``).

    Raises:
        SystemExit: Raised through argparse (exit status 2) when
            ``--config`` or ``--checkpoint`` is missing, when they do not
            have the same number of entries, or when ``--out`` is empty.
    """
    parser = argparse.ArgumentParser(
        description='Model Ensemble with logits result')
    parser.add_argument(
        '--config',
        type=str,
        nargs='+',
        required=True,
        help='ensemble config files path')
    parser.add_argument(
        '--checkpoint',
        type=str,
        nargs='+',
        required=True,
        help='ensemble checkpoint files path')
    parser.add_argument(
        '--aug-test',
        action='store_true',
        help='control ensemble aug-result or single-result (default)')
    parser.add_argument(
        '--out', type=str, default='results', help='the dir to save result')
    parser.add_argument(
        '--gpus', type=int, nargs='+', default=[0], help='id of gpu to use')

    args = parser.parse_args()

    # Validate with parser.error() rather than ``assert`` so the checks are
    # not stripped when Python runs with -O.
    if len(args.config) != len(args.checkpoint):
        parser.error(
            'len(config) must equal len(checkpoint), '
            f'but len(config) = {len(args.config)} and '
            f'len(checkpoint) = {len(args.checkpoint)}')
    if not args.out:
        parser.error("ensemble result out-dir can't be None")
    return args
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line and run the ensemble.
    main(parse_args())
|