mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Merge branch 'open-mmlab:master' into custom/face_occlusion
This commit is contained in:
commit
c222684c29
@ -424,3 +424,32 @@ result/pred_result.pkl \
|
||||
result/confusion_matrix \
|
||||
--show
|
||||
```
|
||||
|
||||
## Model ensemble
|
||||
|
||||
We provide `tools/model_ensemble.py` to ensemble the prediction probabilities of multiple models.
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
python tools/model_ensemble.py \
|
||||
--config ${CONFIG_FILE1} ${CONFIG_FILE2} ... \
|
||||
--checkpoint ${CHECKPOINT_FILE1} ${CHECKPOINT_FILE2} ...\
|
||||
--aug-test \
|
||||
--out ${OUTPUT_DIR}\
|
||||
--gpus ${GPU_USED}
|
||||
```
|
||||
|
||||
### Description of all arguments
|
||||
|
||||
- `--config`: Paths to the config files of the models to ensemble (one per model)
|
||||
- `--checkpoint`: Paths to the checkpoint files of the models to ensemble, in the same order as `--config`
|
||||
- `--aug-test`: Whether to use flip and multi-scale test
|
||||
- `--out`: Save folder for model ensemble results
|
||||
- `--gpus`: GPU id(s) used for model ensemble
|
||||
|
||||
### Result of model ensemble
|
||||
|
||||
- For each input of shape `[H, W]`, the model ensemble generates an unrendered segmentation mask of shape `[H, W]`. Each pixel value in the segmentation mask represents the predicted category at that position.
|
||||
|
||||
- Each model ensemble result is saved under the same filename as its `Ground Truth`. For example, if the `Ground Truth` file is named `1.png`, the model ensemble result is also named `1.png` and is placed in the folder specified by `--out`.
|
||||
|
@ -366,3 +366,31 @@ configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py \
|
||||
checkpoint/fcn_r50-d8_512x1024_40k_cityscapes_20200604_192608-efe53f0d.pth \
|
||||
fcn
|
||||
```
|
||||
|
||||
## 模型集成
|
||||
|
||||
我们提供了 `tools/model_ensemble.py` 脚本,用于对多个模型的预测概率进行集成。
|
||||
|
||||
### 使用方法
|
||||
|
||||
```bash
|
||||
python tools/model_ensemble.py \
|
||||
--config ${CONFIG_FILE1} ${CONFIG_FILE2} ... \
|
||||
--checkpoint ${CHECKPOINT_FILE1} ${CHECKPOINT_FILE2} ...\
|
||||
--aug-test \
|
||||
--out ${OUTPUT_DIR}\
|
||||
--gpus ${GPU_USED}
|
||||
```
|
||||
|
||||
### 各个参数的描述:
|
||||
|
||||
- `--config`: 集成模型的配置文件的路径
|
||||
- `--checkpoint`: 集成模型的权重文件的路径
|
||||
- `--aug-test`: 是否使用翻转和多尺度预测
|
||||
- `--out`: 模型集成结果的保存文件夹路径
|
||||
- `--gpus`: 模型集成使用的gpu-id
|
||||
|
||||
### 模型集成结果
|
||||
|
||||
- 对于每一张形状为 `[H, W]` 的输入,模型集成会产生一张形状为 `[H, W]` 的未渲染分割掩膜文件(segmentation mask),分割掩膜中每个像素点的值代表该位置分割后的像素类别。
|
||||
- 模型集成结果的文件名会采用和 `Ground Truth` 一致的文件命名,如 `Ground Truth` 文件名称为 `1.png`,则模型集成结果文件也会被命名为 `1.png`,并放置在 `--out` 指定的文件夹中。
|
||||
|
@ -278,6 +278,15 @@ class EncoderDecoder(BaseSegmentor):
|
||||
seg_pred = list(seg_pred)
|
||||
return seg_pred
|
||||
|
||||
def simple_test_logits(self, img, img_metas, rescale=True):
    """Test without augmentations and return the raw segmentation logits.

    Only the first element of ``img`` and of ``img_metas`` is used.

    Args:
        img: Indexable collection of inputs; ``img[0]`` is passed to
            ``self.inference``.
        img_metas: Indexable collection of meta information;
            ``img_metas[0]`` is passed to ``self.inference``.
        rescale (bool): Forwarded unchanged to ``self.inference``.
            Default: True.

    Returns:
        The result of ``self.inference`` after ``.cpu().numpy()``, i.e. a
        NumPy seg_map logits array (presumably shaped (N, C, H, W) --
        TODO confirm against ``inference``).
    """
    seg_logit = self.inference(img[0], img_metas[0], rescale)
    # Hand the logits back as a host-side array so that callers can combine
    # results without caring where the model ran.
    return seg_logit.cpu().numpy()
|
||||
|
||||
def aug_test(self, imgs, img_metas, rescale=True):
|
||||
"""Test with augmentations.
|
||||
|
||||
@ -300,3 +309,21 @@ class EncoderDecoder(BaseSegmentor):
|
||||
# unravel batch dim
|
||||
seg_pred = list(seg_pred)
|
||||
return seg_pred
|
||||
|
||||
def aug_test_logits(self, img, img_metas, rescale=True):
    """Test with augmentations and return the averaged segmentation logits.

    ``self.inference`` is run once per entry of ``img`` (paired by index
    with ``img_metas``); the results are summed in place and divided by the
    number of entries.

    Args:
        img: Sequence of augmented inputs, one entry per augmentation.
        img_metas: Sequence of meta information, indexed like ``img``.
        rescale (bool): Must be truthy; only ``rescale=True`` is supported.
            Default: True.

    Returns:
        The averaged result after ``.cpu().numpy()``, i.e. a NumPy seg_map
        logits array.

    Raises:
        AssertionError: If ``rescale`` is falsy.
        IndexError: If ``img_metas`` has fewer entries than ``img``.
    """
    # aug_test rescale all imgs back to ori_shape for now
    assert rescale

    num_augs = len(img)
    seg_logit = self.inference(img[0], img_metas[0], rescale)
    # Index-based on purpose: a too-short ``img_metas`` must raise instead
    # of silently shrinking the set of averaged predictions.
    for aug_idx in range(1, num_augs):
        seg_logit += self.inference(img[aug_idx], img_metas[aug_idx],
                                    rescale)

    seg_logit /= num_augs
    return seg_logit.cpu().numpy()
|
||||
|
121
tools/model_ensemble.py
Normal file
121
tools/model_ensemble.py
Normal file
@ -0,0 +1,121 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import MMDataParallel
|
||||
from mmcv.parallel.scatter_gather import scatter_kwargs
|
||||
from mmcv.runner import load_checkpoint, wrap_fp16_model
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.datasets import build_dataloader, build_dataset
|
||||
from mmseg.models import build_segmentor
|
||||
|
||||
|
||||
@torch.no_grad()
def main(args):
    """Ensemble several segmentors by summing their logits and save the masks.

    The test dataset and its pipeline come from the first config only. One
    model is built per ``(config, checkpoint)`` pair and assigned to the GPUs
    in ``args.gpus`` round-robin. For every sample the logits of all models
    are summed, the arg-max over axis 1 is taken, and the result is written
    as an 8-bit image into ``args.out`` under the basename of the sample's
    ``seg_map`` annotation.

    Args:
        args (argparse.Namespace): Needs the attributes ``config``,
            ``checkpoint``, ``gpus``, ``out`` and ``aug_test`` as produced
            by ``parse_args``.
    """
    gpu_ids = args.gpus
    configs = args.config
    ckpts = args.checkpoint

    # Dataset and test pipeline are defined by the first config alone.
    cfg = mmcv.Config.fromfile(configs[0])

    # NOTE(review): pipeline[1] is presumably the multi-scale/flip
    # augmentation step of the test pipeline -- confirm for custom configs.
    test_aug = cfg.data.test.pipeline[1]
    if args.aug_test:
        test_aug.img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
        test_aug.flip = True
    else:
        test_aug.img_ratios = [1.0]
        test_aug.flip = False

    torch.backends.cudnn.benchmark = True

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=4,
        dist=False,
        shuffle=False,
    )

    # The output directory does not depend on the model, so create it once
    # instead of once per loaded model.
    out_dir = args.out
    mmcv.mkdir_or_exist(out_dir)

    models = []
    for idx, (config, ckpt) in enumerate(zip(configs, ckpts)):
        cfg = mmcv.Config.fromfile(config)
        cfg.model.pretrained = None
        cfg.data.test.test_mode = True

        model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
        if cfg.get('fp16', None):
            wrap_fp16_model(model)
        load_checkpoint(model, ckpt, map_location='cpu')
        torch.cuda.empty_cache()
        # Spread the models over the requested GPUs round-robin.
        model = MMDataParallel(
            model, device_ids=[gpu_ids[idx % len(gpu_ids)]])
        model.eval()
        models.append(model)

    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    # The batch sampler is iterated alongside the loader to recover the
    # dataset index of each batch (valid because shuffle=False).
    loader_indices = data_loader.batch_sampler
    for batch_indices, data in zip(loader_indices, data_loader):
        # Keep a running sum rather than a list of every model's logits, so
        # only one accumulated array stays alive per sample.
        result_logits = 0
        for model in models:
            x, _ = scatter_kwargs(
                inputs=data, kwargs=None, target_gpus=model.device_ids)
            if args.aug_test:
                logits = model.module.aug_test_logits(**x[0])
            else:
                logits = model.module.simple_test_logits(**x[0])
            result_logits += logits

        pred = result_logits.argmax(axis=1).squeeze()
        img_info = dataset.img_infos[batch_indices[0]]
        # basename() also copes with '/'-separated seg_map paths on
        # platforms whose native separator is different.
        file_name = os.path.join(
            out_dir, os.path.basename(img_info['ann']['seg_map']))
        # NOTE: the uint8 cast wraps class indices above 255.
        Image.fromarray(pred.astype(np.uint8)).save(file_name)
        prog_bar.update()
|
||||
|
||||
|
||||
def parse_args():
    """Parse and validate the command-line arguments of the ensemble script.

    Returns:
        argparse.Namespace: With ``config`` and ``checkpoint`` (equally long
        lists of str), ``aug_test`` (bool), ``out`` (non-empty str) and
        ``gpus`` (list of int).

    Raises:
        SystemExit: Raised by argparse (exit status 2) when ``--config`` or
            ``--checkpoint`` is missing, when their lengths differ, or when
            ``--out`` is empty.
    """
    parser = argparse.ArgumentParser(
        description='Model Ensemble with logits result')
    # Both lists are required: without them there is nothing to ensemble,
    # and ``len(None)`` below would fail with an unhelpful TypeError.
    parser.add_argument(
        '--config',
        type=str,
        nargs='+',
        required=True,
        help='ensemble config files path')
    parser.add_argument(
        '--checkpoint',
        type=str,
        nargs='+',
        required=True,
        help='ensemble checkpoint files path')
    parser.add_argument(
        '--aug-test',
        action='store_true',
        help='control ensemble aug-result or single-result (default)')
    parser.add_argument(
        '--out', type=str, default='results', help='the dir to save result')
    parser.add_argument(
        '--gpus', type=int, nargs='+', default=[0], help='id of gpu to use')

    args = parser.parse_args()

    # Validate explicitly instead of with ``assert`` so the checks survive
    # ``python -O``; configs and checkpoints are paired positionally.
    if len(args.config) != len(args.checkpoint):
        parser.error(
            f'len(config) must equal len(checkpoint), '
            f'but len(config) = {len(args.config)} and '
            f'len(checkpoint) = {len(args.checkpoint)}')
    if not args.out:
        parser.error("ensemble result out-dir can't be None")
    return args
|
||||
|
||||
|
||||
# Script entry point: parse the command line, then run the ensemble.
if __name__ == '__main__':
    main(parse_args())
|
Loading…
x
Reference in New Issue
Block a user