Add softmax to cls model (#1573)

* Add softmax to cls model

* fix cls ci

* multihead

* update classification_model.py
pull/1597/head
q.yao 2022-12-30 15:43:12 +08:00 committed by GitHub
parent baa86aa4a5
commit deaefacea1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 8 deletions

View File

@ -333,8 +333,8 @@ class Classification(BaseTask):
if 'topk' not in postprocess:
topk = (1, )
logger = get_root_logger()
logger.warning('no topk in postprocess config, using default \
topk value.')
logger.warning('no topk in postprocess config, using default '
'topk value.')
else:
topk = postprocess.topk
postprocess.topk = max(topk)

View File

@ -35,6 +35,7 @@ class End2EndModel(BaseBackendModel):
backend: Backend,
backend_files: Sequence[str],
device: str,
model_cfg: Union[str, Config] = None,
deploy_cfg: Union[str, Config] = None,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
**kwargs):
@ -46,8 +47,18 @@ class End2EndModel(BaseBackendModel):
backend_files=backend_files,
device=device,
**kwargs)
self.model_cfg = model_cfg
self.head = None
if model_cfg is not None:
self.head = self._get_head()
self.device = device
def _get_head(self):
from mmcls.models import build_head
head_config = self.model_cfg['model']['head']
head = build_head(head_config)
return head
def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
device: str, **kwargs):
output_names = self.output_names
@ -84,11 +95,38 @@ class End2EndModel(BaseBackendModel):
cls_score = self.wrapper({self.input_name:
inputs})[self.output_names[0]]
from mmcls.models.heads.cls_head import ClsHead
predict = ClsHead._get_predictions(
None, cls_score, data_samples=data_samples)
from mmcls.models.heads import MultiLabelClsHead
from mmcls.structures import ClsDataSample
pred_scores = cls_score
return predict
if self.head is None or not isinstance(self.head, MultiLabelClsHead):
pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()
if data_samples is not None:
for data_sample, score, label in zip(data_samples, pred_scores,
pred_labels):
data_sample.set_pred_score(score).set_pred_label(label)
else:
data_samples = []
for score, label in zip(pred_scores, pred_labels):
data_samples.append(ClsDataSample().set_pred_score(
score).set_pred_label(label))
else:
if data_samples is None:
data_samples = [
ClsDataSample() for _ in range(cls_score.size(0))
]
for data_sample, score in zip(data_samples, pred_scores):
if self.head.thr is not None:
# a label is predicted positive if larger than thr
label = torch.where(score >= self.head.thr)[0]
else:
# top-k labels will be predicted positive for any example
_, label = score.topk(self.head.topk)
data_sample.set_pred_score(score).set_pred_label(label)
return data_samples
@__BACKEND_MODEL.register_module('sdk')
@ -204,6 +242,7 @@ def build_classification_model(
backend=backend,
backend_files=model_files,
device=device,
model_cfg=model_cfg,
deploy_cfg=deploy_cfg,
data_preprocessor=data_preprocessor,
**kwargs))

View File

@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch
from mmengine.structures import BaseDataElement
from torch import Tensor
from torch.nn import functional as F
from mmdeploy.core import FUNCTION_REWRITER
@ -32,4 +34,10 @@ def base_classifier__forward(
output = self.extract_feat(batch_inputs)
if self.head is not None:
output = self.head(output)
from mmcls.models.heads import MultiLabelClsHead
if isinstance(self.head, MultiLabelClsHead):
output = torch.sigmoid(output)
else:
output = F.softmax(output, dim=1)
return output

View File

@ -9,6 +9,10 @@ from mmdeploy.core.rewriters.rewriter_manager import RewriterContext
from mmdeploy.utils import Backend, Codebase
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
try:
from torch.testing import assert_close as torch_assert_close
except Exception:
from torch.testing import assert_allclose as torch_assert_close
try:
import_codebase(Codebase.MMCLS)
except ImportError:
@ -77,6 +81,7 @@ def test_baseclassifier_forward():
def extract_feat(self, batch_inputs: torch.Tensor):
return batch_inputs
input = torch.rand(1, 1000)
backbone_cfg = dict(
type='ResNet',
depth=18,
@ -90,8 +95,8 @@ def test_baseclassifier_forward():
with RewriterContext({}):
backend_output = model(input)
assert model_output == input
assert backend_output == input
torch_assert_close(model_output, input)
torch_assert_close(backend_output, torch.nn.functional.softmax(input, -1))
@pytest.mark.parametrize(