mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
support load v1/v2 ckpt (#1868)
This commit is contained in:
parent
167f94a70b
commit
bfe0fbe04d
@ -1,45 +1,85 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence, Union
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.runner import load_checkpoint
|
||||
from mmengine import Config
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.runner import load_checkpoint
|
||||
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.models import BaseSegmentor
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import SampleList
|
||||
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
|
||||
|
||||
def init_model(config, checkpoint=None, device='cuda:0'):
|
||||
def init_model(config: Union[str, Path, Config],
|
||||
checkpoint: Optional[str] = None,
|
||||
device: str = 'cuda:0',
|
||||
cfg_options: Optional[dict] = None):
|
||||
"""Initialize a segmentor from config file.
|
||||
|
||||
Args:
|
||||
config (str or :obj:`mmcv.Config`): Config file path or the config
|
||||
object.
|
||||
config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
|
||||
:obj:`Path`, or the config object.
|
||||
checkpoint (str, optional): Checkpoint path. If left as None, the model
|
||||
will not load any weights.
|
||||
device (str, optional) CPU/CUDA device option. Default 'cuda:0'.
|
||||
Use 'cpu' for loading model on CPU.
|
||||
cfg_options (dict, optional): Options to override some settings in
|
||||
the used config.
|
||||
Returns:
|
||||
nn.Module: The constructed segmentor.
|
||||
"""
|
||||
if isinstance(config, str):
|
||||
if isinstance(config, (str, Path)):
|
||||
config = Config.fromfile(config)
|
||||
elif not isinstance(config, mmcv.Config):
|
||||
elif not isinstance(config, Config):
|
||||
raise TypeError('config must be a filename or Config object, '
|
||||
'but got {}'.format(type(config)))
|
||||
if cfg_options is not None:
|
||||
config.merge_from_dict(cfg_options)
|
||||
elif 'init_cfg' in config.model.backbone:
|
||||
config.model.backbone.init_cfg = None
|
||||
config.model.pretrained = None
|
||||
config.model.train_cfg = None
|
||||
model = MODELS.build(config.model)
|
||||
if checkpoint is not None:
|
||||
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
||||
model.CLASSES = checkpoint['meta']['CLASSES']
|
||||
model.PALETTE = checkpoint['meta']['PALETTE']
|
||||
|
||||
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
|
||||
# save the dataset_meta in the model for convenience
|
||||
if 'dataset_meta' in checkpoint.get('meta', {}):
|
||||
# mmseg 1.x
|
||||
model.dataset_meta = dataset_meta
|
||||
elif 'CLASSES' in checkpoint.get('meta', {}):
|
||||
# < mmseg 1.x
|
||||
classes = checkpoint['meta']['CLASSES']
|
||||
palette = checkpoint['meta']['PALETTE']
|
||||
model.dataset_meta = {'classes': classes, 'palette': palette}
|
||||
else:
|
||||
warnings.simplefilter('once')
|
||||
warnings.warn(
|
||||
'dataset_meta or class names are not saved in the '
|
||||
'checkpoint\'s meta data, classes and palette will be'
|
||||
'set according to num_classes ')
|
||||
num_classes = model.decode_head.num_classes
|
||||
dataset_name = None
|
||||
for name in dataset_aliases.keys():
|
||||
if len(get_classes(name)) == num_classes:
|
||||
dataset_name = name
|
||||
break
|
||||
if dataset_name is None:
|
||||
warnings.warn(
|
||||
'No suitable dataset found, use Cityscapes by default')
|
||||
dataset_name = 'cityscapes'
|
||||
model.dataset_meta = {
|
||||
'classes': get_classes(dataset_name),
|
||||
'palette': get_palette(dataset_name)
|
||||
}
|
||||
model.cfg = config # save the config in the model for convenience
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
Loading…
x
Reference in New Issue
Block a user