parent
8633993ca3
commit
9b070a5dda
|
@ -1,13 +1,4 @@
|
|||
from .cls_head import simple_test_of_cls_head
|
||||
from .linear_head import simple_test_of_linear_head
|
||||
from .multi_label_head import simple_test_of_multi_label_head
|
||||
from .multi_label_linear_head import simple_test_of_multi_label_linear_head
|
||||
from .stacked_head import simple_test_of_stacked_head
|
||||
from .vision_transformer_head import simple_test_of_vision_transformer_head
|
||||
from .cls_head import post_process_of_cls_head
|
||||
from .multi_label_head import post_process_of_multi_label_head
|
||||
|
||||
__all__ = [
|
||||
'simple_test_of_multi_label_linear_head',
|
||||
'simple_test_of_multi_label_head', 'simple_test_of_cls_head',
|
||||
'simple_test_of_linear_head', 'simple_test_of_stacked_head',
|
||||
'simple_test_of_vision_transformer_head'
|
||||
]
|
||||
__all__ = ['post_process_of_cls_head', 'post_process_of_multi_label_head']
|
||||
|
|
|
@ -1,13 +1,7 @@
|
|||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.ClsHead.simple_test')
|
||||
def simple_test_of_cls_head(ctx, self, cls_score, **kwargs):
|
||||
"""Test without augmentation."""
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
func_name='mmcls.models.heads.ClsHead.post_process')
|
||||
def post_process_of_cls_head(ctx, self, pred, **kwargs):
|
||||
return pred
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.LinearClsHead.simple_test')
|
||||
def simple_test_of_linear_head(ctx, self, img, **kwargs):
|
||||
"""Test without augmentation."""
|
||||
cls_score = self.fc(img)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
return pred
|
|
@ -1,12 +1,7 @@
|
|||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.MultiLabelClsHead.simple_test')
|
||||
def simple_test_of_multi_label_head(ctx, self, cls_score, **kwargs):
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
||||
func_name='mmcls.models.heads.MultiLabelClsHead.post_process')
|
||||
def post_process_of_multi_label_head(ctx, self, pred, **kwargs):
|
||||
return pred
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.MultiLabelLinearClsHead.simple_test')
|
||||
def simple_test_of_multi_label_linear_head(ctx, self, img, **kwargs):
|
||||
"""Test without augmentation."""
|
||||
cls_score = self.fc(img)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
||||
return pred
|
|
@ -1,16 +0,0 @@
|
|||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.StackedLinearClsHead.simple_test')
|
||||
def simple_test_of_stacked_head(ctx, self, img, **kwargs):
|
||||
"""Test without augmentation."""
|
||||
cls_score = img
|
||||
for layer in self.layers:
|
||||
cls_score = layer(cls_score)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
return pred
|
|
@ -1,14 +0,0 @@
|
|||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.VisionTransformerClsHead.simple_test')
|
||||
def simple_test_of_vision_transformer_head(ctx, self, img, **kwargs):
|
||||
"""Test without augmentation."""
|
||||
cls_score = self.layers(img)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
return pred
|
Loading…
Reference in New Issue