[Enhancement] Refine mmcls rewriting (#106)

* fix mmcls head

* fix lint
pull/12/head
AllentDan 2021-09-28 19:21:51 +08:00 committed by GitHub
parent 8633993ca3
commit 9b070a5dda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 7 additions and 85 deletions

View File

@ -1,13 +1,4 @@
from .cls_head import simple_test_of_cls_head
from .linear_head import simple_test_of_linear_head
from .multi_label_head import simple_test_of_multi_label_head
from .multi_label_linear_head import simple_test_of_multi_label_linear_head
from .stacked_head import simple_test_of_stacked_head
from .vision_transformer_head import simple_test_of_vision_transformer_head
from .cls_head import post_process_of_cls_head
from .multi_label_head import post_process_of_multi_label_head
__all__ = [
'simple_test_of_multi_label_linear_head',
'simple_test_of_multi_label_head', 'simple_test_of_cls_head',
'simple_test_of_linear_head', 'simple_test_of_stacked_head',
'simple_test_of_vision_transformer_head'
]
__all__ = ['post_process_of_cls_head', 'post_process_of_multi_label_head']

View File

@ -1,13 +1,7 @@
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.ClsHead.simple_test')
def simple_test_of_cls_head(ctx, self, cls_score, **kwargs):
"""Test without augmentation."""
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
func_name='mmcls.models.heads.ClsHead.post_process')
def post_process_of_cls_head(ctx, self, pred, **kwargs):
return pred

View File

@ -1,14 +0,0 @@
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.LinearClsHead.simple_test')
def simple_test_of_linear_head(ctx, self, img, **kwargs):
"""Test without augmentation."""
cls_score = self.fc(img)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return pred

View File

@ -1,12 +1,7 @@
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.MultiLabelClsHead.simple_test')
def simple_test_of_multi_label_head(ctx, self, cls_score, **kwargs):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.sigmoid(cls_score) if cls_score is not None else None
func_name='mmcls.models.heads.MultiLabelClsHead.post_process')
def post_process_of_multi_label_head(ctx, self, pred, **kwargs):
return pred

View File

@ -1,14 +0,0 @@
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.MultiLabelLinearClsHead.simple_test')
def simple_test_of_multi_label_linear_head(ctx, self, img, **kwargs):
"""Test without augmentation."""
cls_score = self.fc(img)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.sigmoid(cls_score) if cls_score is not None else None
return pred

View File

@ -1,16 +0,0 @@
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.StackedLinearClsHead.simple_test')
def simple_test_of_stacked_head(ctx, self, img, **kwargs):
"""Test without augmentation."""
cls_score = img
for layer in self.layers:
cls_score = layer(cls_score)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return pred

View File

@ -1,14 +0,0 @@
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.VisionTransformerClsHead.simple_test')
def simple_test_of_vision_transformer_head(ctx, self, img, **kwargs):
"""Test without augmentation."""
cls_score = self.layers(img)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return pred