# 2020-07-08 23:18:54 +08:00
#
# 61 lines
# 2.1 KiB
# Python

import torch.nn as nn
from mmcv.cnn import kaiming_init, normal_init
from ..utils import accuracy
from ..registry import HEADS
@HEADS.register_module
class ClsHead(nn.Module):
    """Simplest classifier head, with only one fc layer.

    Args:
        with_avg_pool (bool): If True, apply ``nn.AdaptiveAvgPool2d((1, 1))``
            to the input feature map before the linear layer.
            Default: False.
        in_channels (int): Number of input features of ``fc_cls``.
            Default: 2048.
        num_classes (int): Number of output logits of ``fc_cls``.
            Default: 1000.
    """

    def __init__(self,
                 with_avg_pool=False,
                 in_channels=2048,
                 num_classes=1000):
        super(ClsHead, self).__init__()
        self.with_avg_pool = with_avg_pool
        self.in_channels = in_channels
        self.num_classes = num_classes

        self.criterion = nn.CrossEntropyLoss()

        if self.with_avg_pool:
            self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_cls = nn.Linear(in_channels, num_classes)

    def init_weights(self, init_linear='normal', std=0.01, bias=0.):
        """Initialize the parameters of every submodule of this head.

        Args:
            init_linear (str): Initialization scheme for ``nn.Linear``
                layers, either ``'normal'`` or ``'kaiming'``.
            std (float): Standard deviation used by ``'normal'`` init.
                Not used by ``'kaiming'``.
            bias (float): Constant value for linear-layer biases, applied
                for both schemes.

        Raises:
            AssertionError: If ``init_linear`` is not a supported scheme.
        """
        assert init_linear in ['normal', 'kaiming'], \
            "Undefined init_linear: {}".format(init_linear)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if init_linear == 'normal':
                    normal_init(m, std=std, bias=bias)
                else:
                    # Forward ``bias`` so the argument is honoured for both
                    # schemes; previously it was silently ignored here.
                    kaiming_init(
                        m, mode='fan_in', nonlinearity='relu', bias=bias)
            elif isinstance(m,
                            (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
                # ClsHead itself creates no norm layers; this branch only
                # applies if such modules are added (e.g. by a subclass).
                # Affine parameters may be absent, hence the None checks.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Compute classification logits.

        Args:
            x (tuple | list): A single-element sequence holding the feature
                tensor. When ``with_avg_pool`` is True the tensor must be
                4-D (presumably N, C, H, W); otherwise it is flattened from
                dim 1 onward as-is.

        Returns:
            list: A one-element list ``[cls_score]``, where ``cls_score``
            has shape ``(x.size(0), num_classes)``.
        """
        assert isinstance(x, (tuple, list)) and len(x) == 1
        x = x[0]
        if self.with_avg_pool:
            assert x.dim() == 4, \
                "Tensor must has 4 dims, got: {}".format(x.dim())
            x = self.avg_pool(x)
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # inputs (e.g. permuted or channels_last features), whereas
        # ``reshape`` returns a view when possible and copies only if needed.
        x = x.reshape(x.size(0), -1)
        cls_score = self.fc_cls(x)
        return [cls_score]

    def loss(self, cls_score, labels):
        """Compute the classification loss and accuracy.

        Args:
            cls_score (tuple | list): One-element sequence of logits, as
                returned by :meth:`forward`.
            labels (Tensor): Targets passed to ``nn.CrossEntropyLoss`` and
                ``accuracy``; presumably class indices of shape (N,) —
                confirm with callers.

        Returns:
            dict: ``'loss'`` holds the cross-entropy loss and ``'acc'`` holds
            the result of ``accuracy(logits, labels)``.
        """
        assert isinstance(cls_score, (tuple, list)) and len(cls_score) == 1
        logits = cls_score[0]
        losses = dict()
        losses['loss'] = self.criterion(logits, labels)
        losses['acc'] = accuracy(logits, labels)
        return losses