Add softmax to cls model (#1573)
* Add softmax to cls model * fix cls ci * multihead * update classification_model.pypull/1597/head
parent
baa86aa4a5
commit
deaefacea1
|
@ -333,8 +333,8 @@ class Classification(BaseTask):
|
|||
if 'topk' not in postprocess:
|
||||
topk = (1, )
|
||||
logger = get_root_logger()
|
||||
logger.warning('no topk in postprocess config, using default \
|
||||
topk value.')
|
||||
logger.warning('no topk in postprocess config, using default '
|
||||
'topk value.')
|
||||
else:
|
||||
topk = postprocess.topk
|
||||
postprocess.topk = max(topk)
|
||||
|
|
|
@ -35,6 +35,7 @@ class End2EndModel(BaseBackendModel):
|
|||
backend: Backend,
|
||||
backend_files: Sequence[str],
|
||||
device: str,
|
||||
model_cfg: Union[str, Config] = None,
|
||||
deploy_cfg: Union[str, Config] = None,
|
||||
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
|
||||
**kwargs):
|
||||
|
@ -46,8 +47,18 @@ class End2EndModel(BaseBackendModel):
|
|||
backend_files=backend_files,
|
||||
device=device,
|
||||
**kwargs)
|
||||
self.model_cfg = model_cfg
|
||||
self.head = None
|
||||
if model_cfg is not None:
|
||||
self.head = self._get_head()
|
||||
self.device = device
|
||||
|
||||
def _get_head(self):
|
||||
from mmcls.models import build_head
|
||||
head_config = self.model_cfg['model']['head']
|
||||
head = build_head(head_config)
|
||||
return head
|
||||
|
||||
def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
|
||||
device: str, **kwargs):
|
||||
output_names = self.output_names
|
||||
|
@ -84,11 +95,38 @@ class End2EndModel(BaseBackendModel):
|
|||
cls_score = self.wrapper({self.input_name:
|
||||
inputs})[self.output_names[0]]
|
||||
|
||||
from mmcls.models.heads.cls_head import ClsHead
|
||||
predict = ClsHead._get_predictions(
|
||||
None, cls_score, data_samples=data_samples)
|
||||
from mmcls.models.heads import MultiLabelClsHead
|
||||
from mmcls.structures import ClsDataSample
|
||||
pred_scores = cls_score
|
||||
|
||||
return predict
|
||||
if self.head is None or not isinstance(self.head, MultiLabelClsHead):
|
||||
pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()
|
||||
|
||||
if data_samples is not None:
|
||||
for data_sample, score, label in zip(data_samples, pred_scores,
|
||||
pred_labels):
|
||||
data_sample.set_pred_score(score).set_pred_label(label)
|
||||
else:
|
||||
data_samples = []
|
||||
for score, label in zip(pred_scores, pred_labels):
|
||||
data_samples.append(ClsDataSample().set_pred_score(
|
||||
score).set_pred_label(label))
|
||||
else:
|
||||
if data_samples is None:
|
||||
data_samples = [
|
||||
ClsDataSample() for _ in range(cls_score.size(0))
|
||||
]
|
||||
|
||||
for data_sample, score in zip(data_samples, pred_scores):
|
||||
if self.head.thr is not None:
|
||||
# a label is predicted positive if larger than thr
|
||||
label = torch.where(score >= self.head.thr)[0]
|
||||
else:
|
||||
# top-k labels will be predicted positive for any example
|
||||
_, label = score.topk(self.head.topk)
|
||||
data_sample.set_pred_score(score).set_pred_label(label)
|
||||
|
||||
return data_samples
|
||||
|
||||
|
||||
@__BACKEND_MODEL.register_module('sdk')
|
||||
|
@ -204,6 +242,7 @@ def build_classification_model(
|
|||
backend=backend,
|
||||
backend_files=model_files,
|
||||
device=device,
|
||||
model_cfg=model_cfg,
|
||||
deploy_cfg=deploy_cfg,
|
||||
data_preprocessor=data_preprocessor,
|
||||
**kwargs))
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from mmengine.structures import BaseDataElement
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
@ -32,4 +34,10 @@ def base_classifier__forward(
|
|||
output = self.extract_feat(batch_inputs)
|
||||
if self.head is not None:
|
||||
output = self.head(output)
|
||||
|
||||
from mmcls.models.heads import MultiLabelClsHead
|
||||
if isinstance(self.head, MultiLabelClsHead):
|
||||
output = torch.sigmoid(output)
|
||||
else:
|
||||
output = F.softmax(output, dim=1)
|
||||
return output
|
||||
|
|
|
@ -9,6 +9,10 @@ from mmdeploy.core.rewriters.rewriter_manager import RewriterContext
|
|||
from mmdeploy.utils import Backend, Codebase
|
||||
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
|
||||
|
||||
try:
|
||||
from torch.testing import assert_close as torch_assert_close
|
||||
except Exception:
|
||||
from torch.testing import assert_allclose as torch_assert_close
|
||||
try:
|
||||
import_codebase(Codebase.MMCLS)
|
||||
except ImportError:
|
||||
|
@ -77,6 +81,7 @@ def test_baseclassifier_forward():
|
|||
def extract_feat(self, batch_inputs: torch.Tensor):
|
||||
return batch_inputs
|
||||
|
||||
input = torch.rand(1, 1000)
|
||||
backbone_cfg = dict(
|
||||
type='ResNet',
|
||||
depth=18,
|
||||
|
@ -90,8 +95,8 @@ def test_baseclassifier_forward():
|
|||
with RewriterContext({}):
|
||||
backend_output = model(input)
|
||||
|
||||
assert model_output == input
|
||||
assert backend_output == input
|
||||
torch_assert_close(model_output, input)
|
||||
torch_assert_close(backend_output, torch.nn.functional.softmax(input, -1))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
Loading…
Reference in New Issue