diff --git a/docs/en/user_guides/inference.md b/docs/en/user_guides/inference.md index a1c970c1e..af66846a5 100644 --- a/docs/en/user_guides/inference.md +++ b/docs/en/user_guides/inference.md @@ -38,10 +38,10 @@ model = init_model(config_path, checkpoint_path, device="cpu") # device can be result = inference_model(model, img_path) ``` -`result` is a dictionary containing `pred_label`, `pred_score` and `pred_score`, the result is as follows: +`result` is a dictionary containing `pred_label`, `pred_score`, `pred_scores` and `pred_class`, the result is as follows: ```text -{"pred_label":65,"pred_score":0.6649366617202759,"pred_class":"sea snake"} +{"pred_label":65,"pred_score":0.6649366617202759,"pred_class":"sea snake", "pred_scores": [..., 0.6649366617202759, ...]} ``` An image demo can be found in [demo/image_demo.py](https://github.com/open-mmlab/mmclassification/blob/1.x/demo/image_demo.py). diff --git a/docs/zh_CN/user_guides/inference.md b/docs/zh_CN/user_guides/inference.md index b1f0afca1..35c745618 100644 --- a/docs/zh_CN/user_guides/inference.md +++ b/docs/zh_CN/user_guides/inference.md @@ -39,10 +39,10 @@ result = inference_model(model, img_path) print(result) ``` -`result` 为一个包含了 `pred_label`, `pred_score` 和 `pred_score`的字典,结果如下: +`result` 为一个包含了 `pred_label`, `pred_score`, `pred_scores` 和 `pred_class`的字典,结果如下: ```text -{"pred_label":65,"pred_score":0.6649366617202759,"pred_score":"sea snake"} +{"pred_label":65,"pred_score":0.6649366617202759,"pred_class":"sea snake", "pred_scores": [..., 0.6649366617202759, ...]} ``` 演示可以在 [demo/image_demo.py](https://github.com/open-mmlab/mmclassification/blob/1.x/demo/image_demo.py) 中找到。 diff --git a/mmcls/apis/inference.py b/mmcls/apis/inference.py index c2a294294..80ddf6f37 100644 --- a/mmcls/apis/inference.py +++ b/mmcls/apis/inference.py @@ -80,9 +80,14 @@ def inference_model(model, img): # forward the model with torch.no_grad(): prediction = model.val_step(data)[0].pred_label + pred_scores = prediction.score.tolist() pred_score = torch.max(prediction.score).item() pred_label = torch.argmax(prediction.score).item() - result = {'pred_label': pred_label, 'pred_score': float(pred_score)} + result = { + 'pred_label': pred_label, + 'pred_score': float(pred_score), + 'pred_scores': pred_scores + } if hasattr(model, 'CLASSES'): result['pred_class'] = model.CLASSES[result['pred_label']] return result