mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Enhancement] Get scores from inference api. (#1070)
This commit is contained in:
parent
ae37d7fd27
commit
a1642e42da
@ -38,10 +38,10 @@ model = init_model(config_path, checkpoint_path, device="cpu") # device can be
|
||||
result = inference_model(model, img_path)
|
||||
```
|
||||
|
||||
`result` is a dictionary containing `pred_label`, `pred_score` and `pred_score`, the result is as follows:
|
||||
`result` is a dictionary containing `pred_label`, `pred_score`, `pred_scores` and `pred_class`, the result is as follows:
|
||||
|
||||
```text
|
||||
{"pred_label":65,"pred_score":0.6649366617202759,"pred_class":"sea snake"}
|
||||
{"pred_label":65,"pred_score":0.6649366617202759,"pred_class":"sea snake", "pred_scores": [..., 0.6649366617202759, ...]}
|
||||
```
|
||||
|
||||
An image demo can be found in [demo/image_demo.py](https://github.com/open-mmlab/mmclassification/blob/1.x/demo/image_demo.py).
|
||||
|
@ -39,10 +39,10 @@ result = inference_model(model, img_path)
|
||||
print(result)
|
||||
```
|
||||
|
||||
`result` 为一个包含了 `pred_label`, `pred_score` 和 `pred_score`的字典,结果如下:
|
||||
`result` 为一个包含了 `pred_label`, `pred_score`, `pred_scores` 和 `pred_class`的字典,结果如下:
|
||||
|
||||
```text
|
||||
{"pred_label":65,"pred_score":0.6649366617202759,"pred_score":"sea snake"}
|
||||
{"pred_label":65,"pred_score":0.6649366617202759,"pred_class":"sea snake", "pred_scores": [..., 0.6649366617202759, ...]}
|
||||
```
|
||||
|
||||
演示可以在 [demo/image_demo.py](https://github.com/open-mmlab/mmclassification/blob/1.x/demo/image_demo.py) 中找到。
|
||||
|
@ -80,9 +80,14 @@ def inference_model(model, img):
|
||||
# forward the model
|
||||
with torch.no_grad():
|
||||
prediction = model.val_step(data)[0].pred_label
|
||||
pred_scores = prediction.score.tolist()
|
||||
pred_score = torch.max(prediction.score).item()
|
||||
pred_label = torch.argmax(prediction.score).item()
|
||||
result = {'pred_label': pred_label, 'pred_score': float(pred_score)}
|
||||
result = {
|
||||
'pred_label': pred_label,
|
||||
'pred_score': float(pred_score),
|
||||
'pred_scores': pred_scores
|
||||
}
|
||||
if hasattr(model, 'CLASSES'):
|
||||
result['pred_class'] = model.CLASSES[result['pred_label']]
|
||||
return result
|
||||
|
Loading…
x
Reference in New Issue
Block a user