mmpretrain/mmcls/apis/inference.py

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
from mmengine.config import Config
from mmengine.data import pseudo_collate
from mmengine.runner import load_checkpoint

from mmcls.datasets.pipelines import Compose
from mmcls.models import build_classifier


def init_model(config, checkpoint=None, device='cuda:0', options=None):
    """Initialize a classifier from a config file.

    Args:
        config (str or :obj:`mmengine.Config`): Config file path or the
            config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights.
        device (str): The device where the model runs. Defaults to 'cuda:0'.
        options (dict, optional): Options to override some settings in the
            used config.

    Returns:
        nn.Module: The constructed classifier.
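
    Example:
        A minimal usage sketch; the config and checkpoint paths below are
        placeholders and should point to files available locally.

        >>> model = init_model('configs/resnet/resnet50_8xb32_in1k.py',
        ...                    checkpoint='resnet50.pth', device='cpu')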
"""
    if isinstance(config, str):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if options is not None:
        config.merge_from_dict(options)
    config.model.pretrained = None
    config.model.setdefault('data_preprocessor',
                            config.get('preprocess_cfg', None))
    model = build_classifier(config.model)
    if checkpoint is not None:
        # Mapping the weights to GPU may cause unexpected video memory leak
        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'CLASSES' in checkpoint.get('meta', {}):
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            from mmcls.datasets import ImageNet
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use imagenet by default.')
            model.CLASSES = ImageNet.CLASSES
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def inference_model(model, img):
    """Inference image(s) with the classifier.

    Args:
        model (BaseClassifier): The loaded classifier.
        img (str/ndarray): The image filename or loaded image.

    Returns:
        result (dict): The classification results that contain
            ``pred_class``, ``pred_label`` and ``pred_score``.
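
    Example:
        A minimal usage sketch; the paths below are placeholders. Both an
        image file path and an already-loaded array are accepted.

        >>> model = init_model('configs/resnet/resnet50_8xb32_in1k.py',
        ...                    checkpoint='resnet50.pth', device='cpu')
        >>> result = inference_model(model, 'demo/demo.JPEG')
        >>> # result holds 'pred_label', 'pred_score' and 'pred_class'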
"""
    cfg = model.cfg
    # build the data pipeline
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    if isinstance(img, str):
        if test_pipeline_cfg[0]['type'] != 'LoadImageFromFile':
            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_path=img)
    else:
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            test_pipeline_cfg.pop(0)
        data = dict(img=img)
    test_pipeline = Compose(test_pipeline_cfg)
    data = test_pipeline(data)
    data = pseudo_collate([data])

    # forward the model
    with torch.no_grad():
        prediction = model.val_step(data)[0].pred_label
        pred_score = torch.max(prediction.score).item()
        pred_label = torch.argmax(prediction.score).item()

    result = {'pred_label': pred_label, 'pred_score': float(pred_score)}
    if hasattr(model, 'CLASSES'):
        result['pred_class'] = model.CLASSES[result['pred_label']]
    return result
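

# A minimal usage sketch (assumption: the config and checkpoint paths below
# are placeholders for files available locally). It exercises both input
# forms accepted by `inference_model`: an image file path and a loaded array.
if __name__ == '__main__':
    import numpy as np

    model = init_model(
        'configs/resnet/resnet50_8xb32_in1k.py',
        checkpoint='resnet50.pth',
        device='cpu')

    # 1) Inference from an image file path (placeholder path).
    print(inference_model(model, 'demo/demo.JPEG'))

    # 2) Inference from an already-loaded HWC uint8 array.
    fake_img = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
    print(inference_model(model, fake_img))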