[Enhancement] Get scores from inference api. (#1070)

This commit is contained in:
takuoko 2022-10-08 16:21:34 +09:00 committed by GitHub
parent ae37d7fd27
commit a1642e42da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 5 deletions

View File

@ -38,10 +38,10 @@ model = init_model(config_path, checkpoint_path, device="cpu") # device can be
result = inference_model(model, img_path)
```
`result` is a dictionary containing `pred_label`, `pred_score` and `pred_score`, the result is as follows:
`result` is a dictionary containing `pred_label`, `pred_score`, `pred_scores` and `pred_class`, the result is as follows:
```text
{"pred_label":65,"pred_score":0.6649366617202759,"pred_class":"sea snake"}
{"pred_label":65,"pred_score":0.6649366617202759,"pred_class":"sea snake", "pred_scores": [..., 0.6649366617202759, ...]}
```
An image demo can be found in [demo/image_demo.py](https://github.com/open-mmlab/mmclassification/blob/1.x/demo/image_demo.py).

View File

@ -39,10 +39,10 @@ result = inference_model(model, img_path)
print(result)
```
`result` 为一个包含了 `pred_label`, `pred_score``pred_score`的字典,结果如下:
`result` 为一个包含了 `pred_label`, `pred_score`, `pred_scores``pred_class`的字典,结果如下:
```text
{"pred_label":65,"pred_score":0.6649366617202759,"pred_score":"sea snake"}
{"pred_label":65,"pred_score":0.6649366617202759,"pred_class":"sea snake", "pred_scores": [..., 0.6649366617202759, ...]}
```
演示可以在 [demo/image_demo.py](https://github.com/open-mmlab/mmclassification/blob/1.x/demo/image_demo.py) 中找到。

View File

@ -80,9 +80,14 @@ def inference_model(model, img):
# forward the model
with torch.no_grad():
prediction = model.val_step(data)[0].pred_label
pred_scores = prediction.score.tolist()
pred_score = torch.max(prediction.score).item()
pred_label = torch.argmax(prediction.score).item()
result = {'pred_label': pred_label, 'pred_score': float(pred_score)}
result = {
'pred_label': pred_label,
'pred_score': float(pred_score),
'pred_scores': pred_scores
}
if hasattr(model, 'CLASSES'):
result['pred_class'] = model.CLASSES[result['pred_label']]
return result