mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Refactor] Use post_process function to handle pred result processing. (#390)
Use post_process function to handle pred result processing in `simple_test`.
This commit is contained in:
parent
e4188c828c
commit
f8f1700860
@ -63,7 +63,9 @@ class ClsHead(BaseHead):
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
return self.post_process(pred)
|
||||
|
||||
def post_process(self, pred):
|
||||
on_trace = is_tracing()
|
||||
if torch.onnx.is_in_onnx_export() or on_trace:
|
||||
return pred
|
||||
|
@ -1,9 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import HEADS
|
||||
from ..utils import is_tracing
|
||||
from .cls_head import ClsHead
|
||||
|
||||
|
||||
@ -43,11 +41,7 @@ class LinearClsHead(ClsHead):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
|
||||
on_trace = is_tracing()
|
||||
if torch.onnx.is_in_onnx_export() or on_trace:
|
||||
return pred
|
||||
pred = list(pred.detach().cpu().numpy())
|
||||
return pred
|
||||
return self.post_process(pred)
|
||||
|
||||
def forward_train(self, x, gt_label):
|
||||
cls_score = self.fc(x)
|
||||
|
@ -49,6 +49,9 @@ class MultiLabelClsHead(BaseHead):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
||||
|
||||
return self.post_process(pred)
|
||||
|
||||
def post_process(self, pred):
|
||||
on_trace = is_tracing()
|
||||
if torch.onnx.is_in_onnx_export() or on_trace:
|
||||
return pred
|
||||
|
@ -1,9 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import HEADS
|
||||
from ..utils import is_tracing
|
||||
from .multi_label_head import MultiLabelClsHead
|
||||
|
||||
|
||||
@ -53,8 +51,4 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
||||
|
||||
on_trace = is_tracing()
|
||||
if torch.onnx.is_in_onnx_export() or on_trace:
|
||||
return pred
|
||||
pred = list(pred.detach().cpu().numpy())
|
||||
return pred
|
||||
return self.post_process(pred)
|
||||
|
@ -1,6 +1,5 @@
|
||||
from typing import Dict, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_activation_layer, build_norm_layer
|
||||
@ -122,10 +121,8 @@ class StackedLinearClsHead(ClsHead):
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
return pred
|
||||
pred = list(pred.detach().cpu().numpy())
|
||||
return pred
|
||||
|
||||
return self.post_process(pred)
|
||||
|
||||
def forward_train(self, x, gt_label):
|
||||
cls_score = x
|
||||
|
@ -1,12 +1,10 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_activation_layer, constant_init, kaiming_init
|
||||
|
||||
from ..builder import HEADS
|
||||
from ..utils import is_tracing
|
||||
from .cls_head import ClsHead
|
||||
|
||||
|
||||
@ -70,11 +68,7 @@ class VisionTransformerClsHead(ClsHead):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
|
||||
on_trace = is_tracing()
|
||||
if torch.onnx.is_in_onnx_export() or on_trace:
|
||||
return pred
|
||||
pred = list(pred.detach().cpu().numpy())
|
||||
return pred
|
||||
return self.post_process(pred)
|
||||
|
||||
def forward_train(self, x, gt_label):
|
||||
cls_score = self.layers(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user