# Source: mmdeploy/mmdeploy/codebase/mmcls/deploy/classification.py
# (349 lines, 12 KiB, Python)

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from copy import deepcopy
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
import torch
from mmengine import Config
from mmengine.model import BaseDataPreprocessor
from mmengine.registry import Registry
from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Codebase, Task, get_root_logger
from mmdeploy.utils.config_utils import get_input_shape
# Registry named 'mmcls_tasks'. It is exposed as `task_registry` on the mmcls
# codebase class and is the registry that the task classes of this module
# register themselves into.
MMCLS_TASK = Registry('mmcls_tasks')
@CODEBASE.register_module(Codebase.MMCLS.value)
class MMClassification(MMCodebase):
    """Codebase entry for mmclassification (mmcls).

    The class is registered in ``CODEBASE`` under ``Codebase.MMCLS.value``
    and exposes ``MMCLS_TASK`` as the registry of its tasks.
    """

    task_registry = MMCLS_TASK

    @classmethod
    def register_deploy_modules(cls):
        """Import the mmdeploy rewriters written for mmcls models.

        The module is imported purely for its import-time side effects.
        """
        import mmdeploy.codebase.mmcls.models  # noqa: F401

    @classmethod
    def register_all_modules(cls):
        """Register the deploy rewriters first, then the mmcls modules."""
        # Aliased so the imported function is not confused with this method.
        from mmcls.utils.setup_env import \
            register_all_modules as register_mmcls_modules

        cls.register_deploy_modules()
        register_mmcls_modules(True)
def process_model_config(model_cfg: Config,
                         imgs: Union[str, np.ndarray],
                         input_shape: Optional[Sequence[int]] = None):
    """Process the model config.

    The input config is never modified; all adjustments are made on a deep
    copy. ``LoadImageFromFile`` is made the first transform of
    ``test_pipeline`` when the (first) image is a path and is removed from
    the first position otherwise. When ``input_shape`` is given, it is
    compared against the first ``crop_size`` found in the pipeline and a
    warning is logged on mismatch.

    Args:
        model_cfg (Config): The model config.
        imgs (str | np.ndarray): Input image(s), accepted data type are `str`,
            `np.ndarray`. A list or tuple of them is accepted as well, in
            which case only the first element is inspected.
        input_shape (list[int]): A list of two integer in (width, height)
            format specifying input shape. Default: None.

    Returns:
        Config: the model config after processing.
    """
    cfg = model_cfg.deepcopy()
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]

    starts_with_loader = \
        cfg.test_pipeline[0]['type'] == 'LoadImageFromFile'
    if isinstance(imgs[0], str):
        if not starts_with_loader:
            cfg.test_pipeline.insert(0, dict(type='LoadImageFromFile'))
    elif starts_with_loader:
        cfg.test_pipeline.pop(0)

    # check whether input_shape is valid
    if input_shape is not None:
        # Search the whole pipeline instead of a fixed position: the index of
        # the crop transform shifts when the loading transform is inserted
        # or removed above, and short pipelines have no third entry at all.
        crop_size = next((transform['crop_size']
                          for transform in cfg.test_pipeline
                          if 'crop_size' in transform), None)
        if crop_size is not None:
            # `crop_size` may be a single int (square crop) or a pair.
            if isinstance(crop_size, int):
                expected_shape = (crop_size, crop_size)
            else:
                expected_shape = tuple(crop_size)
            if tuple(input_shape) != expected_shape:
                logger = get_root_logger()
                logger.warning(
                    '`input shape` should be equal to `crop_size`: '
                    f'{crop_size}, but given: {input_shape}')
    return cfg
def _get_dataset_metainfo(model_cfg: Config):
    """Look up dataset meta information through the dataloader configs.

    The ``test_dataloader``, ``val_dataloader`` and ``train_dataloader``
    entries of the config are tried in that order. Entries that are absent,
    or whose dataset type is not found in the mmcls ``DATASETS`` registry,
    are skipped.

    Args:
        model_cfg (Config): Input model Config object.

    Returns:
        The first non-None result of the dataset class' ``_load_metainfo``
        (called with the ``metainfo`` entry of the dataset config), else the
        dataset class' ``METAINFO`` attribute, else None when no dataloader
        yields either.
    """
    from mmcls import datasets  # noqa
    from mmcls.registry import DATASETS

    registered = DATASETS.module_dict
    for loader_key in ('test_dataloader', 'val_dataloader',
                       'train_dataloader'):
        if loader_key not in model_cfg:
            continue
        dataset_cfg = model_cfg[loader_key].dataset
        dataset_cls = registered.get(dataset_cfg.type, None)
        if dataset_cls is None:
            continue

        load_metainfo = getattr(dataset_cls, '_load_metainfo', None)
        if isinstance(load_metainfo, Callable):
            meta = load_metainfo(dataset_cfg.get('metainfo', None))
            if meta is not None:
                return meta
        if hasattr(dataset_cls, 'METAINFO'):
            return dataset_cls.METAINFO
    return None
@MMCLS_TASK.register_module(Task.CLASSIFICATION.value)
class Classification(BaseTask):
    """Classification task class.

    Args:
        model_cfg (Config): Original PyTorch model config file.
        deploy_cfg (Config): Deployment config file or loaded Config
            object.
        device (str): A string represents device type.
    """

    def __init__(self, model_cfg: Config, deploy_cfg: Config, device: str):
        super().__init__(model_cfg, deploy_cfg, device)

    @staticmethod
    def _move_pack_transform_last(transforms, is_pack):
        """Return a copy of ``transforms`` with the pack transform last.

        The last transform whose ``type`` satisfies ``is_pack`` is moved to
        the end; the relative order of all other transforms is kept.

        Args:
            transforms (Sequence[dict]): Transform configs of a pipeline.
            is_pack (Callable[[str], bool]): Predicate applied to the
                ``type`` of each transform.

        Returns:
            list[dict]: The reordered transform configs.

        Raises:
            IndexError: If no transform satisfies ``is_pack``.
        """
        reordered = list(transforms)
        for idx in range(len(reordered) - 1, -1, -1):
            if is_pack(reordered[idx]['type']):
                reordered.append(reordered.pop(idx))
                return reordered
        raise IndexError('no pack-inputs transform found in test_pipeline')

    def build_data_preprocessor(self):
        """Build data preprocessor.

        Returns:
            nn.Module: A model build with mmcls data_preprocessor.
        """
        from mmengine.registry import MODELS

        # NOTE(review): this reads `preprocess_cfg` while the other methods
        # of this class read `data_preprocessor` - confirm which key is
        # intended.
        # Copy so that `setdefault` never touches the stored model config.
        data_preprocessor = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
        data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
        return MODELS.build(data_preprocessor)

    def build_backend_model(self,
                            model_files: Optional[Sequence[str]] = None,
                            **kwargs) -> torch.nn.Module:
        """Initialize backend model.

        Args:
            model_files (Sequence[str]): Input model files.

        Returns:
            nn.Module: An initialized backend model.
        """
        from .classification_model import build_classification_model

        # Work on a copy: filling in the default `type` must not modify the
        # model config shared with the rest of the task.
        data_preprocessor = deepcopy(self.model_cfg.data_preprocessor)
        data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
        model = build_classification_model(
            model_files,
            self.model_cfg,
            self.deploy_cfg,
            device=self.device,
            data_preprocessor=data_preprocessor)
        model = model.to(self.device)
        return model.eval()

    def create_input(
        self,
        imgs: Union[str, np.ndarray],
        input_shape: Optional[Sequence[int]] = None,
        data_preprocessor: Optional[BaseDataPreprocessor] = None
    ) -> Tuple[Dict, torch.Tensor]:
        """Create input for classifier.

        Args:
            imgs (Union[str, np.ndarray, Sequence]): Input image(s),
                accepted data type are `str`, `np.ndarray`, Sequence.
            input_shape (list[int]): A list of two integer in (width, height)
                format specifying input shape. Default: None.
            data_preprocessor (BaseDataPreprocessor): The data preprocessor
                of the model. Default to `None`.

        Returns:
            tuple: (data, img), meta information for the input image and input.

        Raises:
            IndexError: If the test pipeline has no `PackClsInputs`
                transform.
        """
        from mmengine.dataset import Compose, pseudo_collate

        if not isinstance(imgs, (list, tuple)):
            imgs = [imgs]
        assert 'test_pipeline' in self.model_cfg, \
            f'test_pipeline not found in {self.model_cfg}.'

        # `process_model_config` returns a private copy of the config, so its
        # pipeline can be reordered without another copy.
        model_cfg = process_model_config(self.model_cfg, imgs, input_shape)
        # Transforms listed after the packing step are run before it.
        pipeline = Compose(
            self._move_pack_transform_last(
                model_cfg.test_pipeline,
                lambda type_name: type_name == 'PackClsInputs'))

        data = []
        for img in imgs:
            # Paths are loaded by the pipeline, arrays are passed as pixels.
            if isinstance(img, str):
                sample = dict(img_path=img)
            else:
                sample = dict(img=img)
            data.append(pipeline(sample))
        data = pseudo_collate(data)

        if data_preprocessor is not None:
            data = data_preprocessor(data, False)
            return data, data['inputs']
        return data, BaseTask.get_tensor_from_input(data)

    def get_visualizer(self, name: str, save_dir: str):
        """Get mmcls visualizer.

        Args:
            name (str): Name of visualizer.
            save_dir (str): Directory to save drawn results.

        Returns:
            ClsVisualizer: Instance of mmcls visualizer.
        """
        visualizer = super().get_visualizer(name, save_dir)
        metainfo = _get_dataset_metainfo(self.model_cfg)
        # Keep the visualizer's own dataset meta when none can be found.
        if metainfo is not None:
            visualizer.dataset_meta = metainfo
        return visualizer

    def visualize(self,
                  image: Union[str, np.ndarray],
                  result: list,
                  output_file: str,
                  window_name: str = '',
                  show_result: bool = False,
                  **kwargs):
        """Visualize predictions of a model.

        Args:
            image (str | np.ndarray): Input image to draw predictions on.
            result (list): A list of predictions.
            output_file (str): Output file to save drawn image.
            window_name (str): The name of visualization window. Defaults to
                an empty string.
            show_result (bool): Whether to show result in windows, defaults
                to `False`.
            **kwargs: Accepted but not used by this method.
        """
        save_dir, save_name = osp.split(output_file)
        visualizer = self.get_visualizer(window_name, save_dir)
        # The data sample is named after the output file without extension.
        name = osp.splitext(save_name)[0]
        image = mmcv.imread(image, channel_order='rgb')
        visualizer.add_datasample(
            name, image, result, show=show_result, out_file=output_file)

    @staticmethod
    def get_partition_cfg(partition_type: str) -> Dict:
        """Get a certain partition config.

        Args:
            partition_type (str): A string specifying partition type.

        Returns:
            dict: A dictionary of partition config.

        Raises:
            NotImplementedError: Always; partitioning is not supported yet.
        """
        raise NotImplementedError('Not supported yet.')

    def get_preprocess(self, *args, **kwargs) -> Dict:
        """Get the preprocess information for SDK.

        Return:
            list[dict]: The transform configs of the preprocess pipeline
                (the ``Dict`` annotation is kept unchanged for callers).

        Raises:
            IndexError: If the test pipeline has no `Pack*Inputs` transform.
        """
        import re

        input_shape = get_input_shape(self.deploy_cfg)
        # The empty string is a `str` input, so the processed pipeline
        # starts with `LoadImageFromFile`.
        cfg = process_model_config(self.model_cfg, '', input_shape)
        meta_keys = [
            'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
            'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
            'valid_ratio'
        ]
        # Transforms with 'Random' in their type are dropped.
        transforms = [
            item for item in cfg.test_pipeline if 'Random' not in item['type']
        ]
        pack_pattern = re.compile(r'Pack[a-zA-Z]+Inputs')
        transforms = self._move_pack_transform_last(
            transforms,
            lambda type_name: pack_pattern.search(type_name) is not None)

        for i, transform in enumerate(transforms):
            transform_type = transform['type']
            if transform_type == 'PackClsInputs':
                meta_keys.extend(transform.get('meta_keys', []))
                # De-duplicate while keeping a stable order, so the exported
                # pipeline does not change from run to run.
                transform['meta_keys'] = list(dict.fromkeys(meta_keys))
                transform['keys'] = ['img']
                transform['type'] = 'Collect'
            elif transform_type == 'Resize':
                transform['size'] = transform.pop('scale')
            elif transform_type == 'ResizeEdge':
                # NOTE(review): only `scale` is carried over; any other
                # ResizeEdge option (e.g. which edge) is dropped - confirm.
                transforms[i] = dict(
                    type='Resize',
                    keep_ratio=True,
                    size=(transform['scale'], -1))

        data_preprocessor = self.model_cfg.data_preprocessor
        # Resulting tail of the pipeline: Normalize, ImageToTensor, <last>.
        transforms.insert(-1, dict(type='ImageToTensor', keys=['img']))
        transforms.insert(
            -2,
            dict(
                type='Normalize',
                to_rgb=data_preprocessor.get('to_rgb', False),
                mean=data_preprocessor.get('mean', [0, 0, 0]),
                std=data_preprocessor.get('std', [1, 1, 1])))
        return transforms

    def get_postprocess(self, *args, **kwargs) -> Dict:
        """Get the postprocess information for SDK.

        Return:
            dict: Composed of the postprocess information.
        """
        # Copy the head config: `type` is popped and `topk` overwritten
        # below, which must not leak into the stored model config (doing so
        # would also make a second call fail).
        postprocess = deepcopy(self.model_cfg.model.head)
        if 'topk' not in postprocess:
            topk = (1, )
            logger = get_root_logger()
            logger.warning('no topk in postprocess config, using default '
                           'topk value.')
        else:
            topk = postprocess.topk
        # A bare int is handled as well as a sequence of ints.
        if isinstance(topk, int):
            topk = (topk, )
        postprocess.topk = max(topk)
        return dict(type=postprocess.pop('type'), params=postprocess)

    def get_model_name(self, *args, **kwargs) -> str:
        """Get the model name.

        Return:
            str: the name of the model.
        """
        model = self.model_cfg.model
        # The messages are parenthesised so the whole text belongs to the
        # assert statement.
        assert 'backbone' in model, ('backbone not in model '
                                     'config')
        assert 'type' in model.backbone, ('backbone contains '
                                          'no type')
        return model.backbone.type.lower()