# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from copy import deepcopy
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
import torch
from mmengine import Config
from mmengine.model import BaseDataPreprocessor
from mmengine.registry import Registry

from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Codebase, Task, get_root_logger
from mmdeploy.utils.config_utils import get_input_shape

MMCLS_TASK = Registry('mmcls_tasks')


@CODEBASE.register_module(Codebase.MMCLS.value)
class MMClassification(MMCodebase):
    """mmclassification codebase class."""

    task_registry = MMCLS_TASK

    @classmethod
    def register_deploy_modules(cls):
        """Register all rewriters for mmcls."""
        # Imported only for its side effect of registering the rewriters.
        import mmdeploy.codebase.mmcls.models  # noqa: F401

    @classmethod
    def register_all_modules(cls):
        """Register all related modules and rewriters for mmcls."""
        from mmcls.utils.setup_env import register_all_modules

        cls.register_deploy_modules()
        register_all_modules(True)


def process_model_config(model_cfg: Config,
                         imgs: Union[str, np.ndarray],
                         input_shape: Optional[Sequence[int]] = None):
    """Process the model config.

    A deep copy of ``model_cfg`` is adjusted so that its test pipeline
    starts with ``LoadImageFromFile`` exactly when the inputs are file
    paths. The input config itself is left untouched.

    Args:
        model_cfg (Config): The model config.
        imgs (str | np.ndarray): Input image(s), accepted data type are `str`,
            `np.ndarray`. A list or tuple of those is accepted as well; the
            type of the first element decides how the pipeline is adjusted.
        input_shape (list[int]): A list of two integer in (width, height)
            format specifying input shape. Default: None.

    Returns:
        Config: the model config after processing.
    """
    cfg = model_cfg.deepcopy()
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]

    has_loader = cfg.test_pipeline[0]['type'] == 'LoadImageFromFile'
    if isinstance(imgs[0], str):
        # File paths need a loading step at the head of the pipeline.
        if not has_loader:
            cfg.test_pipeline.insert(0, dict(type='LoadImageFromFile'))
    elif has_loader:
        # In-memory images are already loaded; drop the loading step.
        cfg.test_pipeline.pop(0)

    # check whether input_shape is valid
    # The crop transform is looked up at index 2, as before; the length guard
    # only prevents an IndexError on shorter pipelines.
    if input_shape is not None and len(cfg.test_pipeline) > 2 \
            and 'crop_size' in cfg.test_pipeline[2]:
        crop_size = cfg.test_pipeline[2]['crop_size']
        # `crop_size` may be a single int (square crop) or a sequence.
        if isinstance(crop_size, int):
            expected_shape = (crop_size, crop_size)
        else:
            expected_shape = tuple(crop_size)
        if tuple(input_shape) != expected_shape:
            logger = get_root_logger()
            logger.warning(
                f'`input shape` should be equal to `crop_size`: {crop_size}, '
                f'but given: {input_shape}')
    return cfg


def _get_dataset_metainfo(model_cfg: Config):
    """Get metainfo of dataset.

    The test, val and train dataloaders are inspected in that order and the
    first dataset class that yields metainfo wins.

    Args:
        model_cfg Config: Input model Config object.

    Returns:
        The metainfo obtained from the dataset class (either the result of
        its ``_load_metainfo`` or its ``METAINFO`` attribute), or None when
        no configured dataset provides any.
    """
    from mmcls import datasets  # noqa
    from mmcls.registry import DATASETS

    module_dict = DATASETS.module_dict

    for dataloader_name in [
            'test_dataloader', 'val_dataloader', 'train_dataloader'
    ]:
        if dataloader_name not in model_cfg:
            continue
        dataloader_cfg = model_cfg[dataloader_name]
        dataset_cfg = dataloader_cfg.dataset
        dataset_cls = module_dict.get(dataset_cfg.type, None)
        if dataset_cls is None:
            continue
        if hasattr(dataset_cls, '_load_metainfo') and isinstance(
                dataset_cls._load_metainfo, Callable):
            meta = dataset_cls._load_metainfo(
                dataset_cfg.get('metainfo', None))
            if meta is not None:
                return meta
        if hasattr(dataset_cls, 'METAINFO'):
            return dataset_cls.METAINFO

    return None


@MMCLS_TASK.register_module(Task.CLASSIFICATION.value)
class Classification(BaseTask):
    """Classification task class.

    Args:
        model_cfg (Config): Original PyTorch model config file.
        deploy_cfg (Config): Deployment config file or loaded Config
            object.
        device (str): A string represents device type.
    """

    def __init__(self, model_cfg: Config, deploy_cfg: Config, device: str):
        super(Classification, self).__init__(model_cfg, deploy_cfg, device)

    def build_data_preprocessor(self):
        """Build data preprocessor.

        The preprocessor settings are read from the ``data_preprocessor``
        entry of the model config, which is the key the other methods of
        this class use. The ``preprocess_cfg`` key is still honoured as a
        fallback.

        Returns:
            nn.Module: A model build with mmcls data_preprocessor.
        """
        preprocessor_cfg = self.model_cfg.get('data_preprocessor', None)
        if preprocessor_cfg is None:
            preprocessor_cfg = self.model_cfg.get('preprocess_cfg', None)
        # Copy so that `setdefault` below never touches `self.model_cfg`.
        data_preprocessor = deepcopy(preprocessor_cfg or {})
        data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')

        from mmengine.registry import MODELS
        data_preprocessor = MODELS.build(data_preprocessor)

        return data_preprocessor

    def build_backend_model(self,
                            model_files: Sequence[str] = None,
                            **kwargs) -> torch.nn.Module:
        """Initialize backend model.

        Args:
            model_files (Sequence[str]): Input model files.

        Returns:
            nn.Module: An initialized backend model.
        """
        from .classification_model import build_classification_model

        # NOTE: this fills in the default type on the config held by
        # `self.model_cfg` itself (no copy is made here).
        data_preprocessor = self.model_cfg.data_preprocessor
        data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')

        model = build_classification_model(
            model_files,
            self.model_cfg,
            self.deploy_cfg,
            device=self.device,
            data_preprocessor=data_preprocessor)
        model = model.to(self.device)
        return model.eval()

    def create_input(
        self,
        imgs: Union[str, np.ndarray],
        input_shape: Optional[Sequence[int]] = None,
        data_preprocessor: Optional[BaseDataPreprocessor] = None
    ) -> Tuple[Dict, torch.Tensor]:
        """Create input for classifier.

        Args:
            imgs (Union[str, np.ndarray, Sequence]): Input image(s),
                accepted data type are `str`, `np.ndarray`, Sequence.
            input_shape (list[int]): A list of two integer in (width, height)
                format specifying input shape. Default: None.
            data_preprocessor (BaseDataPreprocessor): The data preprocessor
                of the model. Default to `None`.

        Returns:
            tuple: (data, img), meta information for the input image and input.
        """
        from mmengine.dataset import Compose, pseudo_collate

        if not isinstance(imgs, (list, tuple)):
            imgs = [imgs]
        assert 'test_pipeline' in self.model_cfg, \
            f'test_pipeline not found in {self.model_cfg}.'
        model_cfg = process_model_config(self.model_cfg, imgs, input_shape)
        pipeline = deepcopy(model_cfg.test_pipeline)

        # Transforms that follow `PackClsInputs` are moved in front of it so
        # that the packing step always runs last.
        move_pipeline = []
        while pipeline[-1]['type'] != 'PackClsInputs':
            sub_pipeline = pipeline.pop(-1)
            move_pipeline = [sub_pipeline] + move_pipeline
        pipeline = pipeline[:-1] + move_pipeline + pipeline[-1:]
        pipeline = Compose(pipeline)

        data = []
        for img in imgs:
            # prepare data
            if isinstance(img, str):
                data_ = dict(img_path=img)
            else:
                data_ = dict(img=img)
            # build the data pipeline
            data_ = pipeline(data_)
            data.append(data_)

        data = pseudo_collate(data)
        if data_preprocessor is not None:
            data = data_preprocessor(data, False)
            return data, data['inputs']
        else:
            return data, BaseTask.get_tensor_from_input(data)

    def get_visualizer(self, name: str, save_dir: str):
        """Get mmcls visualizer.

        Args:
            name (str): Name of visualizer.
            save_dir (str): Directory to save drawn results.

        Returns:
            ClsVisualizer: Instance of mmcls visualizer.
        """
        visualizer = super().get_visualizer(name, save_dir)
        metainfo = _get_dataset_metainfo(self.model_cfg)
        if metainfo is not None:
            visualizer.dataset_meta = metainfo
        return visualizer

    def visualize(self,
                  image: Union[str, np.ndarray],
                  result: list,
                  output_file: str,
                  window_name: str = '',
                  show_result: bool = False,
                  **kwargs):
        """Visualize predictions of a model.

        Args:
            image (str | np.ndarray): Input image to draw predictions on.
            result (list): A list of predictions.
            output_file (str): Output file to save drawn image.
            window_name (str): The name of visualization window. Defaults to
                an empty string.
            show_result (bool): Whether to show result in windows, defaults
                to `False`.
        """
        save_dir, save_name = osp.split(output_file)
        visualizer = self.get_visualizer(window_name, save_dir)

        # The data sample is named after the output file without extension.
        name = osp.splitext(save_name)[0]
        image = mmcv.imread(image, channel_order='rgb')
        visualizer.add_datasample(
            name, image, result, show=show_result, out_file=output_file)

    @staticmethod
    def get_partition_cfg(partition_type: str) -> Dict:
        """Get a certain partition config.

        Args:
            partition_type (str): A string specifying partition type.

        Returns:
            dict: A dictionary of partition config.

        Raises:
            NotImplementedError: Always; partitioning is not supported yet.
        """
        raise NotImplementedError('Not supported yet.')

    def get_preprocess(self, *args, **kwargs) -> Dict:
        """Get the preprocess information for SDK.

        Return:
            list[dict]: The transforms derived from the test pipeline, with
            a ``Normalize`` and an ``ImageToTensor`` step inserted before
            the final transform.
        """
        input_shape = get_input_shape(self.deploy_cfg)
        # An empty string stands in for a file path so that the pipeline
        # keeps (or gains) its `LoadImageFromFile` step.
        cfg = process_model_config(self.model_cfg, '', input_shape)
        pipeline = cfg.test_pipeline
        meta_keys = [
            'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
            'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
            'valid_ratio'
        ]
        transforms = [
            item for item in pipeline if 'Random' not in item['type']
        ]

        # Move every transform that follows the packing step in front of it.
        move_pipeline = []
        import re
        while re.search('Pack[a-z | A-Z]+Inputs',
                        transforms[-1]['type']) is None:
            sub_pipeline = transforms.pop(-1)
            move_pipeline = [sub_pipeline] + move_pipeline
        transforms = transforms[:-1] + move_pipeline + transforms[-1:]

        for i, transform in enumerate(transforms):
            if transform['type'] == 'PackClsInputs':
                meta_keys += transform[
                    'meta_keys'] if 'meta_keys' in transform else []
                transform['meta_keys'] = list(set(meta_keys))
                transform['keys'] = ['img']
                transforms[i]['type'] = 'Collect'
            if transform['type'] == 'Resize':
                transforms[i]['size'] = transforms[i].pop('scale')
            if transform['type'] == 'ResizeEdge':
                transforms[i] = dict(
                    type='Resize',
                    keep_ratio=True,
                    size=(transform['scale'], -1))

        data_preprocessor = self.model_cfg.data_preprocessor
        # Resulting tail order: ..., Normalize, ImageToTensor, <last transform>.
        transforms.insert(-1, dict(type='ImageToTensor', keys=['img']))
        transforms.insert(
            -2,
            dict(
                type='Normalize',
                to_rgb=data_preprocessor.get('to_rgb', False),
                mean=data_preprocessor.get('mean', [0, 0, 0]),
                std=data_preprocessor.get('std', [1, 1, 1])))
        return transforms

    def get_postprocess(self, *args, **kwargs) -> Dict:
        """Get the postprocess information for SDK.

        The head config is copied before it is edited, so ``self.model_cfg``
        is left unchanged and the method can be called repeatedly.

        Return:
            dict: Composed of the postprocess information: ``type`` is the
            head type and ``params`` holds the remaining head settings with
            ``topk`` reduced to its largest value.
        """
        postprocess = deepcopy(self.model_cfg.model.head)
        if 'topk' not in postprocess:
            topk = (1, )
            logger = get_root_logger()
            logger.warning(
                'no topk in postprocess config, using default topk value.')
        else:
            topk = postprocess.topk
        # A bare int is wrapped so that `max` works on it as well.
        if isinstance(topk, int):
            topk = (topk, )
        postprocess.topk = max(topk)
        return dict(type=postprocess.pop('type'), params=postprocess)

    def get_model_name(self, *args, **kwargs) -> str:
        """Get the model name.

        Return:
            str: the name of the model, i.e. the lower-cased backbone type.

        Raises:
            AssertionError: If the model config has no backbone or the
                backbone config has no type.
        """
        assert 'backbone' in self.model_cfg.model, (
            'backbone not in model '
            'config')
        assert 'type' in self.model_cfg.model.backbone, (
            'backbone contains '
            'no type')
        name = self.model_cfg.model.backbone.type.lower()
        return name