import torch.nn as nn
from mmcv.cnn import kaiming_init, normal_init

from .registry import NECKS


def _init_weights(module, init_linear='normal'):
    """Initialize the linear and normalization layers found in ``module``.

    Shared by every neck in this file so that they all treat the same set
    of layer types identically.

    Args:
        module (nn.Module): Module whose sub-modules are initialized
            in place.
        init_linear (str): ``'normal'`` calls ``normal_init`` with
            ``std=0.01`` on each ``nn.Linear``; ``'kaiming'`` calls
            ``kaiming_init`` with ``mode='fan_in'`` and
            ``nonlinearity='relu'``.

    Raises:
        AssertionError: If ``init_linear`` is not one of the two
            supported values.
    """
    assert init_linear in ['normal', 'kaiming'], \
        "Undefined init_linear: {}".format(init_linear)
    # Resolved at call time (not import time), like the original code.
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm,
                  nn.SyncBatchNorm)
    for m in module.modules():
        if isinstance(m, nn.Linear):
            if init_linear == 'normal':
                normal_init(m, std=0.01)
            else:
                kaiming_init(m, mode='fan_in', nonlinearity='relu')
        elif isinstance(m, norm_types):
            # Norm layers built with affine=False have no weight/bias.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


@NECKS.register_module
class LinearNeck(nn.Module):
    """Neck made of an optional global average pool and one linear layer.

    Args:
        in_channels (int): Input feature size of the linear layer.
        out_channels (int): Output feature size of the linear layer.
        with_avg_pool (bool): Whether to apply
            ``nn.AdaptiveAvgPool2d((1, 1))`` before flattening.
    """

    def __init__(self, in_channels, out_channels, with_avg_pool=True):
        super(LinearNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_channels, out_channels)

    def init_weights(self, init_linear='normal'):
        """Initialize the layers; see ``_init_weights`` for the options."""
        _init_weights(self, init_linear)

    def forward(self, x):
        """Project the single input feature through the linear layer.

        Args:
            x (list | tuple): Sequence holding exactly one tensor.

        Returns:
            list: One-element list with the linear layer's output.
        """
        assert len(x) == 1
        # Unwrap unconditionally: the tensor is needed below whether or
        # not pooling is enabled (previously only the pooled path did so).
        feat = x[0]
        if self.with_avg_pool:
            feat = self.avgpool(feat)
        return [self.fc(feat.view(feat.size(0), -1))]


@NECKS.register_module
class NonLinearNeckV0(nn.Module):
    """Neck: optional average pool, then fc-bn-relu-dropout-fc-relu.

    Args:
        in_channels (int): Input feature size of the first linear layer.
        hid_channels (int): Hidden feature size.
        out_channels (int): Output feature size of the last linear layer.
        with_avg_pool (bool): Whether to apply
            ``nn.AdaptiveAvgPool2d((1, 1))`` before flattening.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=True):
        super(NonLinearNeckV0, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hid_channels),
            nn.BatchNorm1d(hid_channels, momentum=0.001, affine=False),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hid_channels, out_channels),
            nn.ReLU(inplace=True))

    def init_weights(self, init_linear='normal'):
        """Initialize the layers; see ``_init_weights`` for the options."""
        _init_weights(self, init_linear)

    def forward(self, x):
        """Run the single input feature through the MLP.

        Args:
            x (list | tuple): Sequence holding exactly one tensor.

        Returns:
            list: One-element list with the MLP output.
        """
        assert len(x) == 1
        # Unwrap unconditionally so the non-pooled path also gets a tensor.
        feat = x[0]
        if self.with_avg_pool:
            feat = self.avgpool(feat)
        return [self.mlp(feat.view(feat.size(0), -1))]


@NECKS.register_module
class NonLinearNeckV1(nn.Module):
    """Neck: optional average pool, then fc-relu-fc.

    Args:
        in_channels (int): Input feature size of the first linear layer.
        hid_channels (int): Hidden feature size.
        out_channels (int): Output feature size of the last linear layer.
        with_avg_pool (bool): Whether to apply
            ``nn.AdaptiveAvgPool2d((1, 1))`` before flattening.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=True):
        super(NonLinearNeckV1, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hid_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hid_channels, out_channels))

    def init_weights(self, init_linear='normal'):
        """Initialize the layers; see ``_init_weights`` for the options."""
        _init_weights(self, init_linear)

    def forward(self, x):
        """Run the single input feature through the MLP.

        Args:
            x (list | tuple): Sequence holding exactly one tensor.

        Returns:
            list: One-element list with the MLP output.
        """
        assert len(x) == 1
        # Unwrap unconditionally so the non-pooled path also gets a tensor.
        feat = x[0]
        if self.with_avg_pool:
            feat = self.avgpool(feat)
        return [self.mlp(feat.view(feat.size(0), -1))]


@NECKS.register_module
class AvgPoolNeck(nn.Module):
    """Neck that only applies ``nn.AdaptiveAvgPool2d((1, 1))``."""

    def __init__(self):
        super(AvgPoolNeck, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def init_weights(self, **kwargs):
        """No-op: this neck has no learnable parameters to initialize."""
        pass

    def forward(self, x):
        """Average-pool the single input feature.

        Args:
            x (list | tuple): Sequence holding exactly one tensor.

        Returns:
            list: One-element list with the pooled tensor; the output is
                not flattened here.
        """
        assert len(x) == 1
        return [self.avg_pool(x[0])]