mmsegmentation/mmseg/apis/inference.py

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Optional, Sequence, Union

import mmcv
import numpy as np
import torch
from mmengine import Config
from mmengine.dataset import Compose
from mmengine.runner import load_checkpoint
from mmengine.utils import mkdir_or_exist

from mmseg.models import BaseSegmentor
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
from mmseg.visualization import SegLocalVisualizer


def init_model(config: Union[str, Path, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               cfg_options: Optional[dict] = None):
    """Initialize a segmentor from config file.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str, optional): CPU/CUDA device option. Defaults to 'cuda:0'.
            Use 'cpu' for loading the model on CPU.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed segmentor.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    elif 'init_cfg' in config.model.backbone:
        config.model.backbone.init_cfg = None
    config.model.pretrained = None
    config.model.train_cfg = None
    model = MODELS.build(config.model)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        dataset_meta = checkpoint['meta'].get('dataset_meta', None)
        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in checkpoint.get('meta', {}):
            # mmseg 1.x
            model.dataset_meta = dataset_meta
        elif 'CLASSES' in checkpoint.get('meta', {}):
            # < mmseg 1.x
            classes = checkpoint['meta']['CLASSES']
            palette = checkpoint['meta']['PALETTE']
            model.dataset_meta = {'classes': classes, 'palette': palette}
        else:
            warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, classes and palette will be '
                'set according to num_classes')
            num_classes = model.decode_head.num_classes
            dataset_name = None
            for name in dataset_aliases.keys():
                if len(get_classes(name)) == num_classes:
                    dataset_name = name
                    break
            if dataset_name is None:
                warnings.warn(
                    'No suitable dataset found, use Cityscapes by default')
                dataset_name = 'cityscapes'
            model.dataset_meta = {
                'classes': get_classes(dataset_name),
                'palette': get_palette(dataset_name)
            }

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
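
# Example (a hedged sketch, not part of the upstream module; the config and
# checkpoint paths are placeholders to be replaced with real files):
#
#     from mmseg.apis import init_model
#
#     config_file = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
#     checkpoint_file = 'pspnet_cityscapes.pth'  # placeholder path
#     model = init_model(config_file, checkpoint_file, device='cuda:0')
#     # model.dataset_meta now holds the classes/palette resolved above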


ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def _preprare_data(imgs: ImageType, model: BaseSegmentor):
    cfg = model.cfg
    if dict(type='LoadAnnotations') in cfg.test_pipeline:
        cfg.test_pipeline.remove(dict(type='LoadAnnotations'))

    is_batch = True
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]
        is_batch = False

    if isinstance(imgs[0], np.ndarray):
        cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'

    # TODO: Consider using the singleton pattern to avoid building
    # a pipeline for each inference
    pipeline = Compose(cfg.test_pipeline)

    data = defaultdict(list)
    for img in imgs:
        if isinstance(img, np.ndarray):
            data_ = dict(img=img)
        else:
            data_ = dict(img_path=img)
        data_ = pipeline(data_)
        data['inputs'].append(data_['inputs'])
        data['data_samples'].append(data_['data_samples'])

    return data, is_batch
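
# Illustrative note (not part of the upstream file): with a default
# `PackSegInputs`-style test pipeline, `_preprare_data` is expected to return
# roughly
#
#     data = {
#         'inputs': [image_tensor, ...],         # one packed tensor per input
#         'data_samples': [SegDataSample, ...],  # per-image meta information
#     }
#
# together with `is_batch`, which is False when a single path/ndarray was given.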


def inference_model(model: BaseSegmentor,
                    img: ImageType) -> Union[SegDataSample, SampleList]:
    """Inference image(s) with the segmentor.

    Args:
        model (nn.Module): The loaded segmentor.
        img (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        :obj:`SegDataSample` or list[:obj:`SegDataSample`]:
            If img is a list or tuple, a list of results of the same length is
            returned; otherwise the segmentation result is returned directly.
    """
    # prepare data
    data, is_batch = _preprare_data(img, model)

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)

    return results if is_batch else results[0]
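
# Example (a minimal sketch, not upstream code; 'demo.png' is a placeholder):
#
#     result = inference_model(model, 'demo.png')
#     pred_mask = result.pred_sem_seg.data      # label map tensor, shape (1, H, W)
#     results = inference_model(model, ['a.png', 'b.png'])  # list in -> list out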


def show_result_pyplot(model: BaseSegmentor,
                       img: Union[str, np.ndarray],
                       result: SegDataSample,
                       opacity: float = 0.5,
                       title: str = '',
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       wait_time: float = 0,
                       show: bool = True,
                       save_dir=None,
                       out_file=None):
    """Visualize the segmentation results on the image.

    Args:
        model (nn.Module): The loaded segmentor.
        img (str or np.ndarray): Image filename or loaded image.
        result (SegDataSample): The prediction SegDataSample result.
        opacity (float): Opacity of the painted segmentation map.
            Defaults to 0.5. Must be in the (0, 1] range.
        title (str): The title of the pyplot figure. Defaults to ''.
        draw_gt (bool): Whether to draw the GT SegDataSample. Defaults to True.
        draw_pred (bool): Whether to draw the prediction SegDataSample.
            Defaults to True.
        wait_time (float): The interval of show (s). 0 is the special value
            that means "forever". Defaults to 0.
        show (bool): Whether to display the drawn image. Defaults to True.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        out_file (str, optional): Path to the output file. Defaults to None.

    Returns:
        np.ndarray: The drawn image whose channel order is RGB.
    """
    if hasattr(model, 'module'):
        model = model.module
    if isinstance(img, str):
        image = mmcv.imread(img)
    else:
        image = img
    if save_dir is not None:
        mkdir_or_exist(save_dir)
    # init visualizer
    visualizer = SegLocalVisualizer(
        vis_backends=[dict(type='LocalVisBackend')],
        save_dir=save_dir,
        alpha=opacity)
    visualizer.dataset_meta = dict(
        classes=model.dataset_meta['classes'],
        palette=model.dataset_meta['palette'])
    visualizer.add_datasample(
        name=title,
        image=image,
        data_sample=result,
        draw_gt=draw_gt,
        draw_pred=draw_pred,
        wait_time=wait_time,
        out_file=out_file,
        show=show)
    vis_img = visualizer.get_image()
    return vis_img
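
# End-to-end example (a hedged sketch; all file paths are placeholders):
#
#     from mmseg.apis import init_model, inference_model, show_result_pyplot
#
#     model = init_model('some_config.py', 'some_checkpoint.pth', device='cpu')
#     result = inference_model(model, 'demo.png')
#     vis = show_result_pyplot(model, 'demo.png', result, show=False,
#                              opacity=0.6, out_file='work_dirs/demo_pred.png')
#     # `vis` is an RGB ndarray and can also be saved manually, e.g. with
#     # mmcv.imwrite(mmcv.rgb2bgr(vis), 'demo_pred_rgb.png')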