mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Support MMSegInferencer (#2413)
## Motivation Support `MMSegInferencer` for providing an easy and clean interface for single or multiple images inferencing. Ref: https://github.com/open-mmlab/mmengine/pull/773 https://github.com/open-mmlab/mmocr/pull/1608 ## Modification - mmseg/apis/mmseg_inferencer.py - mmseg/visualization/local_visualizer.py - demo/image_demo_with_inferencer.py ## Use cases (Optional) Based on https://github.com/open-mmlab/mmengine/tree/inference Add a new image inference demo with `MMSegInferencer` - demo/image_demo_with_inferencer.py ```shell python demo/image_demo_with_inferencer.py demo/demo.png fcn_r50-d8_4xb2-40k_cityscapes-512x1024 ``` --------- Co-authored-by: MeowZheng <meowzheng@outlook.com>
This commit is contained in:
parent
039ba5d4ca
commit
53fe1ccf39
54
demo/image_demo_with_inferencer.py
Normal file
54
demo/image_demo_with_inferencer.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from mmseg.apis import MMSegInferencer
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('img', help='Image file')
|
||||
parser.add_argument('model', help='Config file')
|
||||
parser.add_argument('--checkpoint', default=None, help='Checkpoint file')
|
||||
parser.add_argument(
|
||||
'--out-dir', default='', help='Path to save result file')
|
||||
parser.add_argument(
|
||||
'--show',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Whether to display the drawn image.')
|
||||
parser.add_argument(
|
||||
'--save-mask',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Enable save the mask file')
|
||||
parser.add_argument(
|
||||
'--dataset-name',
|
||||
default='cityscapes',
|
||||
help='Color palette used for segmentation map')
|
||||
parser.add_argument(
|
||||
'--device', default='cuda:0', help='Device used for inference')
|
||||
parser.add_argument(
|
||||
'--opacity',
|
||||
type=float,
|
||||
default=0.5,
|
||||
help='Opacity of painted segmentation map. In (0, 1] range.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# build the model from a config file and a checkpoint file
|
||||
mmseg_inferencer = MMSegInferencer(
|
||||
args.model,
|
||||
args.checkpoint,
|
||||
dataset_name=args.dataset_name,
|
||||
device=args.device)
|
||||
|
||||
# test a single image
|
||||
mmseg_inferencer(
|
||||
args.img,
|
||||
show=args.show,
|
||||
out_dir=args.out_dir,
|
||||
save_mask=args.save_mask,
|
||||
opacity=args.opacity)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -9,7 +9,7 @@ from .version import __version__, version_info
|
||||
|
||||
MMCV_MIN = '2.0.0rc4'
|
||||
MMCV_MAX = '2.1.0'
|
||||
MMENGINE_MIN = '0.4.0'
|
||||
MMENGINE_MIN = '0.5.0'
|
||||
MMENGINE_MAX = '1.0.0'
|
||||
|
||||
|
||||
|
@ -1,4 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .inference import inference_model, init_model, show_result_pyplot
|
||||
from .mmseg_inferencer import MMSegInferencer
|
||||
|
||||
__all__ = ['init_model', 'inference_model', 'show_result_pyplot']
|
||||
__all__ = [
|
||||
'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer'
|
||||
]
|
||||
|
279
mmseg/apis/mmseg_inferencer.py
Normal file
279
mmseg/apis/mmseg_inferencer.py
Normal file
@ -0,0 +1,279 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import numpy as np
|
||||
from mmcv.transforms import Compose
|
||||
from mmengine.infer.infer import BaseInferencer, ModelType
|
||||
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import ConfigType, SampleList, register_all_modules
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
|
||||
InputType = Union[str, np.ndarray]
|
||||
InputsType = Union[InputType, Sequence[InputType]]
|
||||
PredType = Union[SegDataSample, SampleList]
|
||||
|
||||
|
||||
class MMSegInferencer(BaseInferencer):
|
||||
"""Semantic segmentation inferencer, provides inference and visualization
|
||||
interfaces. Note: MMEngine >= 0.5.0 is required.
|
||||
|
||||
Args:
|
||||
model (str, optional): Path to the config file or the model name
|
||||
defined in metafile. For example, it could be
|
||||
"fcn_r50-d8_4xb2-40k_cityscapes-512x1024" or
|
||||
"configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py"
|
||||
weights (str, optional): Path to the checkpoint. If it is not specified
|
||||
and model is a model name of metafile, the weights will be loaded
|
||||
from metafile. Defaults to None.
|
||||
palette (List[List[int]], optional): The palette of
|
||||
segmentation map.
|
||||
classes (Tuple[str], optional): Category information.
|
||||
dataset_name (str, optional): Name of the datasets supported in mmseg.
|
||||
device (str, optional): Device to run inference. If None, the available
|
||||
device will be automatically used. Defaults to None.
|
||||
scope (str, optional): The scope of the model. Defaults to None.
|
||||
"""
|
||||
|
||||
preprocess_kwargs: set = set()
|
||||
forward_kwargs: set = {'mode', 'out_dir'}
|
||||
visualize_kwargs: set = {
|
||||
'show', 'wait_time', 'draw_pred', 'img_out_dir', 'opacity'
|
||||
}
|
||||
postprocess_kwargs: set = {
|
||||
'pred_out_dir', 'return_datasample', 'save_mask', 'mask_dir'
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
model: Union[ModelType, str],
|
||||
weights: Optional[str] = None,
|
||||
palette: Optional[Union[str, List]] = None,
|
||||
classes: Optional[Union[str, List]] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
device: Optional[str] = None,
|
||||
scope: Optional[str] = 'mmseg') -> None:
|
||||
# A global counter tracking the number of images processes, for
|
||||
# naming of the output images
|
||||
self.num_visualized_imgs = 0
|
||||
register_all_modules()
|
||||
super().__init__(
|
||||
model=model, weights=weights, device=device, scope=scope)
|
||||
|
||||
assert isinstance(self.visualizer, SegLocalVisualizer)
|
||||
self.visualizer.set_dataset_meta(palette, classes, dataset_name)
|
||||
|
||||
def __call__(self,
|
||||
inputs: InputsType,
|
||||
return_datasamples: bool = False,
|
||||
batch_size: int = 1,
|
||||
show: bool = False,
|
||||
wait_time: int = 0,
|
||||
draw_pred: bool = True,
|
||||
out_dir: str = '',
|
||||
save_mask: bool = False,
|
||||
mask_dir: str = 'mask',
|
||||
**kwargs) -> dict:
|
||||
"""Call the inferencer.
|
||||
|
||||
Args:
|
||||
inputs (Union[str, np.ndarray]): Inputs for the inferencer.
|
||||
return_datasamples (bool): Whether to return results as
|
||||
:obj:`SegDataSample`. Defaults to False.
|
||||
batch_size (int): Batch size. Defaults to 1.
|
||||
show (bool): Whether to display the image in a popup window.
|
||||
Defaults to False.
|
||||
wait_time (float): The interval of show (s). Defaults to 0.
|
||||
draw_pred (bool): Whether to draw Prediction SegDataSample.
|
||||
Defaults to True.
|
||||
out_dir (str): Output directory of inference results. Defaults: ''.
|
||||
save_mask (bool): Whether save pred mask as a file.
|
||||
mask_dir (str): Sub directory of `pred_out_dir`, used to save pred
|
||||
mask file.
|
||||
|
||||
Returns:
|
||||
dict: Inference and visualization results.
|
||||
"""
|
||||
return super().__call__(
|
||||
inputs=inputs,
|
||||
return_datasamples=return_datasamples,
|
||||
batch_size=batch_size,
|
||||
show=show,
|
||||
wait_time=wait_time,
|
||||
draw_pred=draw_pred,
|
||||
img_out_dir=out_dir,
|
||||
pred_out_dir=out_dir,
|
||||
save_mask=save_mask,
|
||||
mask_dir=mask_dir,
|
||||
**kwargs)
|
||||
|
||||
def visualize(self,
|
||||
inputs: list,
|
||||
preds: List[dict],
|
||||
show: bool = False,
|
||||
wait_time: int = 0,
|
||||
draw_pred: bool = True,
|
||||
img_out_dir: str = '',
|
||||
opacity: float = 0.8) -> List[np.ndarray]:
|
||||
"""Visualize predictions.
|
||||
|
||||
Args:
|
||||
inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
|
||||
preds (Any): Predictions of the model.
|
||||
show (bool): Whether to display the image in a popup window.
|
||||
Defaults to False.
|
||||
wait_time (float): The interval of show (s). Defaults to 0.
|
||||
draw_pred (bool): Whether to draw Prediction SegDataSample.
|
||||
Defaults to True.
|
||||
img_out_dir (str): Output directory of drawn images. Defaults: ''
|
||||
opacity (int, float): The transparency of segmentation mask.
|
||||
Defaults to 0.8.
|
||||
|
||||
Returns:
|
||||
List[np.ndarray]: Visualization results.
|
||||
"""
|
||||
if self.visualizer is None or (not show and img_out_dir == ''):
|
||||
return None
|
||||
|
||||
if getattr(self, 'visualizer') is None:
|
||||
raise ValueError('Visualization needs the "visualizer" term'
|
||||
'defined in the config, but got None')
|
||||
|
||||
self.visualizer.alpha = opacity
|
||||
|
||||
results = []
|
||||
|
||||
for single_input, pred in zip(inputs, preds):
|
||||
if isinstance(single_input, str):
|
||||
img_bytes = mmengine.fileio.get(single_input)
|
||||
img = mmcv.imfrombytes(img_bytes)
|
||||
img = img[:, :, ::-1]
|
||||
img_name = osp.basename(single_input)
|
||||
elif isinstance(single_input, np.ndarray):
|
||||
img = single_input.copy()
|
||||
img_num = str(self.num_visualized_imgs).zfill(8)
|
||||
img_name = f'{img_num}.jpg'
|
||||
else:
|
||||
raise ValueError('Unsupported input type:'
|
||||
f'{type(single_input)}')
|
||||
|
||||
out_file = osp.join(img_out_dir, img_name) if img_out_dir != ''\
|
||||
else None
|
||||
|
||||
self.visualizer.add_datasample(
|
||||
img_name,
|
||||
img,
|
||||
pred,
|
||||
show=show,
|
||||
wait_time=wait_time,
|
||||
draw_gt=False,
|
||||
draw_pred=draw_pred,
|
||||
out_file=out_file)
|
||||
results.append(self.visualizer.get_image())
|
||||
self.num_visualized_imgs += 1
|
||||
|
||||
return results
|
||||
|
||||
def postprocess(self,
|
||||
preds: PredType,
|
||||
visualization: List[np.ndarray],
|
||||
return_datasample: bool = False,
|
||||
mask_dir: str = 'mask',
|
||||
save_mask: bool = True,
|
||||
pred_out_dir: str = '') -> dict:
|
||||
"""Process the predictions and visualization results from ``forward``
|
||||
and ``visualize``.
|
||||
|
||||
This method should be responsible for the following tasks:
|
||||
|
||||
1. Convert datasamples into a json-serializable dict if needed.
|
||||
2. Pack the predictions and visualization results and return them.
|
||||
3. Dump or log the predictions.
|
||||
|
||||
Args:
|
||||
preds (List[Dict]): Predictions of the model.
|
||||
visualization (np.ndarray): Visualized predictions.
|
||||
return_datasample (bool): Whether to return results as datasamples.
|
||||
Defaults to False.
|
||||
pred_out_dir: File to save the inference results w/o
|
||||
visualization. If left as empty, no file will be saved.
|
||||
Defaults to ''.
|
||||
mask_dir (str): Sub directory of `pred_out_dir`, used to save pred
|
||||
mask file.
|
||||
save_mask (bool): Whether save pred mask as a file.
|
||||
|
||||
Returns:
|
||||
dict: Inference and visualization results with key ``predictions``
|
||||
and ``visualization``
|
||||
|
||||
- ``visualization (Any)``: Returned by :meth:`visualize`
|
||||
- ``predictions`` (dict or DataSample): Returned by
|
||||
:meth:`forward` and processed in :meth:`postprocess`.
|
||||
If ``return_datasample=False``, it usually should be a
|
||||
json-serializable dict containing only basic data elements such
|
||||
as strings and numbers.
|
||||
"""
|
||||
results_dict = {}
|
||||
|
||||
results_dict['predictions'] = preds
|
||||
results_dict['visualization'] = visualization
|
||||
|
||||
if pred_out_dir != '':
|
||||
mmengine.mkdir_or_exist(pred_out_dir)
|
||||
if save_mask:
|
||||
preds = [preds] if isinstance(preds, SegDataSample) else preds
|
||||
for pred in preds:
|
||||
mmcv.imwrite(
|
||||
pred.pred_sem_seg.numpy().data[0],
|
||||
osp.join(pred_out_dir, mask_dir,
|
||||
osp.basename(pred.metainfo['img_path'])))
|
||||
else:
|
||||
mmengine.dump(results_dict,
|
||||
osp.join(pred_out_dir, 'results.pkl'))
|
||||
|
||||
if return_datasample:
|
||||
return preds
|
||||
|
||||
return results_dict
|
||||
|
||||
def _init_pipeline(self, cfg: ConfigType) -> Compose:
|
||||
"""Initialize the test pipeline.
|
||||
|
||||
Return a pipeline to handle various input data, such as ``str``,
|
||||
``np.ndarray``. It is an abstract method in BaseInferencer, and should
|
||||
be implemented in subclasses.
|
||||
|
||||
The returned pipeline will be used to process a single data.
|
||||
It will be used in :meth:`preprocess` like this:
|
||||
|
||||
.. code-block:: python
|
||||
def preprocess(self, inputs, batch_size, **kwargs):
|
||||
...
|
||||
dataset = map(self.pipeline, dataset)
|
||||
...
|
||||
"""
|
||||
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||
# Loading annotations is also not applicable
|
||||
idx = self._get_transform_idx(pipeline_cfg, 'LoadAnnotations')
|
||||
if idx != -1:
|
||||
del pipeline_cfg[idx]
|
||||
load_img_idx = self._get_transform_idx(pipeline_cfg,
|
||||
'LoadImageFromFile')
|
||||
|
||||
if load_img_idx == -1:
|
||||
raise ValueError(
|
||||
'LoadImageFromFile is not found in the test pipeline')
|
||||
pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader'
|
||||
return Compose(pipeline_cfg)
|
||||
|
||||
def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
|
||||
"""Returns the index of the transform in a pipeline.
|
||||
|
||||
If the transform is not found, returns -1.
|
||||
"""
|
||||
for i, transform in enumerate(pipeline_cfg):
|
||||
if transform['type'] == name:
|
||||
return i
|
||||
return -1
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine.fileio as fileio
|
||||
@ -437,3 +437,59 @@ class LoadBiomedicalData(BaseTransform):
|
||||
f'to_xyz={self.to_xyz}, '
|
||||
f'backend_args={self.backend_args})')
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class InferencerLoader(BaseTransform):
|
||||
"""Load an image from ``results['img']``.
|
||||
|
||||
Similar with :obj:`LoadImageFromFile`, but the image has been loaded as
|
||||
:obj:`np.ndarray` in ``results['img']``. Can be used when loading image
|
||||
from webcam.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_path
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is an uint8 array.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.from_file = TRANSFORMS.build(
|
||||
dict(type='LoadImageFromFile', **kwargs))
|
||||
self.from_ndarray = TRANSFORMS.build(
|
||||
dict(type='LoadImageFromNDArray', **kwargs))
|
||||
|
||||
def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict:
|
||||
"""Transform function to add image meta information.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict with Webcam read image in
|
||||
``results['img']``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
if isinstance(single_input, str):
|
||||
inputs = dict(img_path=single_input)
|
||||
elif isinstance(single_input, np.ndarray):
|
||||
inputs = dict(img=single_input)
|
||||
elif isinstance(single_input, dict):
|
||||
inputs = single_input
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if 'img' in inputs:
|
||||
return self.from_ndarray(inputs)
|
||||
return self.from_file(inputs)
|
||||
|
@ -63,16 +63,7 @@ class SegLocalVisualizer(Visualizer):
|
||||
**kwargs):
|
||||
super().__init__(name, image, vis_backends, save_dir, **kwargs)
|
||||
self.alpha: float = alpha
|
||||
# Set default value. When calling
|
||||
# `SegLocalVisualizer().dataset_meta=xxx`,
|
||||
# it will override the default value.
|
||||
if dataset_name is None:
|
||||
dataset_name = 'cityscapes'
|
||||
classes = classes if classes else get_classes(dataset_name)
|
||||
palette = palette if palette else get_palette(dataset_name)
|
||||
assert len(classes) == len(
|
||||
palette), 'The length of classes should be equal to palette'
|
||||
self.dataset_meta: dict = {'classes': classes, 'palette': palette}
|
||||
self.set_dataset_meta(palette, classes, dataset_name)
|
||||
|
||||
def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
|
||||
classes: Optional[Tuple[str]],
|
||||
@ -109,6 +100,21 @@ class SegLocalVisualizer(Visualizer):
|
||||
|
||||
return self.get_image()
|
||||
|
||||
def set_dataset_meta(self,
|
||||
palette: Optional[Union[str, List]] = None,
|
||||
classes: Optional[Union[str, List]] = None,
|
||||
dataset_name: Optional[str] = None) -> None:
|
||||
# Set default value. When calling
|
||||
# `SegLocalVisualizer().dataset_meta=xxx`,
|
||||
# it will override the default value.
|
||||
if dataset_name is None:
|
||||
dataset_name = 'cityscapes'
|
||||
classes = classes if classes else get_classes(dataset_name)
|
||||
palette = palette if palette else get_palette(dataset_name)
|
||||
assert len(classes) == len(
|
||||
palette), 'The length of classes should be equal to palette'
|
||||
self.dataset_meta: dict = {'classes': classes, 'palette': palette}
|
||||
|
||||
@master_only
|
||||
def add_datasample(
|
||||
self,
|
||||
@ -186,6 +192,6 @@ class SegLocalVisualizer(Visualizer):
|
||||
self.show(drawn_img, win_name=name, wait_time=wait_time)
|
||||
|
||||
if out_file is not None:
|
||||
mmcv.imwrite(drawn_img, out_file)
|
||||
mmcv.imwrite(mmcv.bgr2rgb(drawn_img), out_file)
|
||||
else:
|
||||
self.add_image(name, drawn_img, step)
|
||||
|
@ -1,4 +1,4 @@
|
||||
mmcls>=1.0.0rc0
|
||||
mmcv>=2.0.0rc4
|
||||
-e git+https://github.com/open-mmlab/mmdetection.git@dev-3.x#egg=mmdet
|
||||
mmengine>=0.4.0,<1.0.0
|
||||
mmengine>=0.5.0,<1.0.0
|
||||
|
115
tests/test_apis/test_inferencer.py
Normal file
115
tests/test_apis/test_inferencer.py
Normal file
@ -0,0 +1,115 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine import ConfigDict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmseg.apis import MMSegInferencer
|
||||
from mmseg.models import EncoderDecoder
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
@MODELS.register_module(name='InferExampleHead')
|
||||
class ExampleDecodeHead(BaseDecodeHead):
|
||||
|
||||
def __init__(self, num_classes=19, out_channels=None):
|
||||
super().__init__(
|
||||
3, 3, num_classes=num_classes, out_channels=out_channels)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.cls_seg(inputs[0])
|
||||
|
||||
|
||||
@MODELS.register_module(name='InferExampleBackbone')
|
||||
class ExampleBackbone(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(3, 3, 3)
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
return [self.conv(x)]
|
||||
|
||||
|
||||
@MODELS.register_module(name='InferExampleModel')
|
||||
class ExampleModel(EncoderDecoder):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return dict(img=torch.tensor([1]), img_metas=dict())
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
def test_inferencer():
|
||||
register_all_modules()
|
||||
test_dataset = ExampleDataset()
|
||||
data_loader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
sampler=None,
|
||||
num_workers=0,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
visualizer = dict(
|
||||
type='SegLocalVisualizer',
|
||||
vis_backends=[dict(type='LocalVisBackend')],
|
||||
name='visualizer')
|
||||
|
||||
cfg_dict = dict(
|
||||
model=dict(
|
||||
type='InferExampleModel',
|
||||
data_preprocessor=dict(type='SegDataPreProcessor'),
|
||||
backbone=dict(type='InferExampleBackbone'),
|
||||
decode_head=dict(type='InferExampleHead'),
|
||||
test_cfg=dict(mode='whole')),
|
||||
visualizer=visualizer,
|
||||
test_dataloader=data_loader)
|
||||
cfg = ConfigDict(cfg_dict)
|
||||
model = MODELS.build(cfg.model)
|
||||
|
||||
ckpt = model.state_dict()
|
||||
ckpt_filename = tempfile.mktemp()
|
||||
torch.save(ckpt, ckpt_filename)
|
||||
|
||||
# test initialization
|
||||
infer = MMSegInferencer(cfg, ckpt_filename)
|
||||
|
||||
# test forward
|
||||
img = np.random.randint(0, 256, (4, 4, 3))
|
||||
infer(img)
|
||||
|
||||
imgs = [img, img]
|
||||
infer(imgs)
|
||||
results = infer(imgs, out_dir=tempfile.gettempdir(), draw_pred=True)
|
||||
|
||||
# test results
|
||||
assert 'predictions' in results
|
||||
assert 'visualization' in results
|
||||
assert len(results['predictions']) == 2
|
||||
assert results['predictions'][0].seg_logits.data.shape == torch.Size(
|
||||
(19, 4, 4))
|
||||
assert results['predictions'][0].pred_sem_seg.shape == torch.Size((4, 4))
|
Loading…
x
Reference in New Issue
Block a user