[Refactor] Use post_process function to handle pred result processing. (#390)

Use post_process function to handle pred result processing in `simple_test`.
This commit is contained in:
Ma Zerun 2021-08-12 11:54:24 +08:00 committed by GitHub
parent e4188c828c
commit f8f1700860
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 10 additions and 26 deletions

View File

@@ -63,7 +63,9 @@ class ClsHead(BaseHead):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return self.post_process(pred)
def post_process(self, pred):
on_trace = is_tracing()
if torch.onnx.is_in_onnx_export() or on_trace:
return pred

View File

@@ -1,9 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import HEADS
from ..utils import is_tracing
from .cls_head import ClsHead
@@ -43,11 +41,7 @@ class LinearClsHead(ClsHead):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
on_trace = is_tracing()
if torch.onnx.is_in_onnx_export() or on_trace:
return pred
pred = list(pred.detach().cpu().numpy())
return pred
return self.post_process(pred)
def forward_train(self, x, gt_label):
cls_score = self.fc(x)

View File

@@ -49,6 +49,9 @@ class MultiLabelClsHead(BaseHead):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.sigmoid(cls_score) if cls_score is not None else None
return self.post_process(pred)
def post_process(self, pred):
on_trace = is_tracing()
if torch.onnx.is_in_onnx_export() or on_trace:
return pred

View File

@@ -1,9 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import HEADS
from ..utils import is_tracing
from .multi_label_head import MultiLabelClsHead
@@ -53,8 +51,4 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.sigmoid(cls_score) if cls_score is not None else None
on_trace = is_tracing()
if torch.onnx.is_in_onnx_export() or on_trace:
return pred
pred = list(pred.detach().cpu().numpy())
return pred
return self.post_process(pred)

View File

@@ -1,6 +1,5 @@
from typing import Dict, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, build_norm_layer
@@ -122,10 +121,8 @@ class StackedLinearClsHead(ClsHead):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
if torch.onnx.is_in_onnx_export():
return pred
pred = list(pred.detach().cpu().numpy())
return pred
return self.post_process(pred)
def forward_train(self, x, gt_label):
cls_score = x

View File

@@ -1,12 +1,10 @@
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, constant_init, kaiming_init
from ..builder import HEADS
from ..utils import is_tracing
from .cls_head import ClsHead
@@ -70,11 +68,7 @@ class VisionTransformerClsHead(ClsHead):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
on_trace = is_tracing()
if torch.onnx.is_in_onnx_export() or on_trace:
return pred
pred = list(pred.detach().cpu().numpy())
return pred
return self.post_process(pred)
def forward_train(self, x, gt_label):
cls_score = self.layers(x)