From bfe0fbe04da4ccc36be5acf4a6e4aef598f04731 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Fri, 5 Aug 2022 20:18:55 +0800 Subject: [PATCH] support load v1/v2 ckpt (#1868) --- mmseg/apis/inference.py | 62 +++++++++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 11 deletions(-) diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index ac88f295b..5eff2d228 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -1,45 +1,85 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Sequence, Union +import warnings +from pathlib import Path +from typing import Optional, Sequence, Union import mmcv import numpy as np import torch -from mmcv.runner import load_checkpoint from mmengine import Config from mmengine.dataset import Compose +from mmengine.runner import load_checkpoint +from mmseg.data import SegDataSample from mmseg.models import BaseSegmentor from mmseg.registry import MODELS -from mmseg.structures import SegDataSample -from mmseg.utils import SampleList +from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette from mmseg.visualization import SegLocalVisualizer -def init_model(config, checkpoint=None, device='cuda:0'): +def init_model(config: Union[str, Path, Config], + checkpoint: Optional[str] = None, + device: str = 'cuda:0', + cfg_options: Optional[dict] = None): """Initialize a segmentor from config file. Args: - config (str or :obj:`mmcv.Config`): Config file path or the config - object. + config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path, + :obj:`Path`, or the config object. checkpoint (str, optional): Checkpoint path. If left as None, the model will not load any weights. device (str, optional) CPU/CUDA device option. Default 'cuda:0'. Use 'cpu' for loading model on CPU. + cfg_options (dict, optional): Options to override some settings in + the used config. Returns: nn.Module: The constructed segmentor. """ - if isinstance(config, str): + if isinstance(config, (str, Path)): config = Config.fromfile(config) - elif not isinstance(config, mmcv.Config): + elif not isinstance(config, Config): raise TypeError('config must be a filename or Config object, ' 'but got {}'.format(type(config))) + if cfg_options is not None: + config.merge_from_dict(cfg_options) + elif 'init_cfg' in config.model.backbone: + config.model.backbone.init_cfg = None config.model.pretrained = None config.model.train_cfg = None model = MODELS.build(config.model) if checkpoint is not None: checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') - model.CLASSES = checkpoint['meta']['CLASSES'] - model.PALETTE = checkpoint['meta']['PALETTE'] + + dataset_meta = checkpoint['meta'].get('dataset_meta', None) + # save the dataset_meta in the model for convenience + if 'dataset_meta' in checkpoint.get('meta', {}): + # mmseg 1.x + model.dataset_meta = dataset_meta + elif 'CLASSES' in checkpoint.get('meta', {}): + # < mmseg 1.x + classes = checkpoint['meta']['CLASSES'] + palette = checkpoint['meta']['PALETTE'] + model.dataset_meta = {'classes': classes, 'palette': palette} + else: + warnings.simplefilter('once') + warnings.warn( + 'dataset_meta or class names are not saved in the ' + 'checkpoint\'s meta data, classes and palette will be' + 'set according to num_classes ') + num_classes = model.decode_head.num_classes + dataset_name = None + for name in dataset_aliases.keys(): + if len(get_classes(name)) == num_classes: + dataset_name = name + break + if dataset_name is None: + warnings.warn( + 'No suitable dataset found, use Cityscapes by default') + dataset_name = 'cityscapes' + model.dataset_meta = { + 'classes': get_classes(dataset_name), + 'palette': get_palette(dataset_name) + } model.cfg = config # save the config in the model for convenience model.to(device) model.eval()