2021-08-17 19:52:42 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
2022-12-19 13:01:11 +08:00
|
|
|
from typing import TYPE_CHECKING, Union
|
2020-05-21 21:21:43 +08:00
|
|
|
|
2022-12-06 17:00:22 +08:00
|
|
|
import numpy as np
|
2020-09-30 19:00:20 +08:00
|
|
|
import torch
|
2020-05-21 21:21:43 +08:00
|
|
|
|
2022-12-19 13:01:11 +08:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
from mmengine.model import BaseModel
|
2020-09-30 19:00:20 +08:00
|
|
|
|
|
|
|
|
2022-12-19 13:01:11 +08:00
|
|
|
def inference_model(model: 'BaseModel', img: Union[str, np.ndarray]):
    """Run the classifier on a single image and summarize the prediction.

    The test pipeline is read from ``model.cfg.test_dataloader.dataset
    .pipeline``. A ``LoadImageFromFile`` step is prepended when ``img`` is a
    filename and removed when ``img`` is anything else, so both input kinds
    share the remaining transforms. The adjustment is made on a copy; the
    pipeline list stored in ``model.cfg`` is left untouched.

    Args:
        model (BaseModel): The loaded classifier. It must expose ``cfg`` and
            ``val_step``; ``CLASSES`` is optional.
        img (str | np.ndarray): The image filename or the loaded image.

    Returns:
        dict: The classification result with the keys

        - ``pred_label``: index of the highest score.
        - ``pred_score``: the highest score.
        - ``pred_scores``: all scores as a (nested) Python list.
        - ``pred_class``: ``model.CLASSES[pred_label]``; only present when
          the model has a ``CLASSES`` attribute.
    """
    # Imported lazily so that importing this module does not require
    # mmengine to be importable.
    from mmengine.dataset import Compose, default_collate
    from mmengine.registry import DefaultScope

    cfg = model.cfg

    # Build the data pipeline. Work on a shallow copy: inserting into or
    # popping from the list held by ``model.cfg`` would silently change the
    # model's configuration for every later user of that config.
    test_pipeline_cfg = list(cfg.test_dataloader.dataset.pipeline)
    loads_from_file = (
        bool(test_pipeline_cfg)
        and test_pipeline_cfg[0]['type'] == 'LoadImageFromFile')

    if isinstance(img, str):
        # A filename needs a loading step in front of the other transforms.
        if not loads_from_file:
            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_path=img)
    else:
        # The image is already in memory, so the loading step is dropped.
        if loads_from_file:
            test_pipeline_cfg.pop(0)
        data = dict(img=img)

    # Transform types are resolved while the default scope is 'mmcls'.
    with DefaultScope.overwrite_default_scope('mmcls'):
        test_pipeline = Compose(test_pipeline_cfg)
    data = test_pipeline(data)
    data = default_collate([data])

    # Forward the model without tracking gradients; only the first (and
    # only) sample of the collated batch is of interest.
    with torch.no_grad():
        prediction = model.val_step(data)[0].pred_label

    scores = prediction.score
    result = {
        'pred_label': torch.argmax(scores).item(),
        'pred_score': torch.max(scores).item(),
        'pred_scores': scores.tolist(),
    }
    if hasattr(model, 'CLASSES'):
        result['pred_class'] = model.CLASSES[result['pred_label']]
    return result
|