mmpretrain/mmcls/apis/inference.py

80 lines
2.7 KiB
Python
Raw Normal View History

import warnings
2020-05-21 21:21:43 +08:00
import mmcv
import numpy as np
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
2020-05-21 21:21:43 +08:00
from mmcls.datasets.pipelines import Compose
from mmcls.models import build_classifier
def init_model(config, checkpoint=None, device='cuda:0'):
    """Initialize a classifier from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str): Device the model is moved to with ``model.to``.
            Defaults to 'cuda:0'.

    Returns:
        nn.Module: The constructed classifier, switched to eval mode, with
            the config attached as ``model.cfg``.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    # Weights come only from ``checkpoint``; drop any pretrained entry in
    # the config before building the model.
    config.model.pretrained = None
    model = build_classifier(config.model)
    if checkpoint is not None:
        # Always map to CPU: the model is moved to ``device`` below, and this
        # avoids depending on the device the checkpoint was saved from.
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        # A checkpoint without a 'meta' entry is treated the same as one
        # whose meta lacks class names, instead of raising KeyError.
        meta = checkpoint.get('meta', {})
        if 'CLASSES' in meta:
            model.CLASSES = meta['CLASSES']
        else:
            from mmcls.datasets import ImageNet
            # NOTE: this changes the process-wide warnings filter.
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use imagenet by default.')
            model.CLASSES = ImageNet.CLASSES
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
def inference_model(model, img):
    """Run the classifier on a single image and report the top prediction.

    Args:
        model (nn.Module): The loaded classifier. It must carry ``cfg`` and
            ``CLASSES`` attributes.
        img (str): The image filename; it is handed to the test pipeline as
            ``img_info['filename']``.

    Returns:
        result (dict): The classification results that contains
            `class_name`, `pred_label` and `pred_score`.
    """
    test_cfg = model.cfg.data.test
    # Any parameter will do for finding out where the model lives.
    first_param = next(model.parameters())

    # Build the test-time pipeline and push the single sample through it.
    pipeline = Compose(test_cfg.pipeline)
    sample = pipeline(dict(img_info=dict(filename=img), img_prefix=None))
    batch = collate([sample], samples_per_gpu=1)

    if first_param.is_cuda:
        # scatter to the GPU that holds the model
        batch = scatter(batch, [first_param.device])[0]

    # Forward pass without gradient tracking.
    with torch.no_grad():
        scores = model(return_loss=False, **batch)
        top_score = np.max(scores, axis=1)[0]
        top_label = np.argmax(scores, axis=1)[0]

    return {
        'pred_label': top_label,
        'pred_score': top_score,
        'class_name': model.CLASSES[top_label],
    }