from __future__ import absolute_import, division, print_function

from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import constant_init, kaiming_init
from torch.nn.modules.batchnorm import _BatchNorm

from ..modelzoo import inceptionv4 as model_urls
from ..registry import BACKBONES

__all__ = ['Inception4']


class BasicConv2d(nn.Module):

    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size,
                 stride=1,
                 padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False)  # verify bias false
        self.bn = nn.BatchNorm2d(
            out_planes,
            eps=0.001,  # value found in tensorflow
            momentum=0.1,  # default pytorch value
            affine=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Mixed_3a(nn.Module):

    def __init__(self):
        super(Mixed_3a, self).__init__()
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)

    def forward(self, x):
        x0 = self.maxpool(x)
        x1 = self.conv(x)
        out = torch.cat((x0, x1), 1)
        return out


class Mixed_4a(nn.Module):

    def __init__(self):
        super(Mixed_4a, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1))

        self.branch1 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(64, 96, kernel_size=(3, 3), stride=1))

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        return out


class Mixed_5a(nn.Module):

    def __init__(self):
        super(Mixed_5a, self).__init__()
        self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
        self.maxpool = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.conv(x)
        x1 = self.maxpool(x)
        out = torch.cat((x0, x1), 1)
        return out


class Inception_A(nn.Module):

    def __init__(self):
        super(Inception_A, self).__init__()
        self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1))

        self.branch2 = nn.Sequential(
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1))

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(384, 96, kernel_size=1, stride=1))

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Reduction_A(nn.Module):

    def __init__(self):
        super(Reduction_A, self).__init__()
        self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
            BasicConv2d(384, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
            BasicConv2d(224, 256, kernel_size=3, stride=2))

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out
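
# Feature map sizes through the stem and Inception-A stage for a standard
# 3 x 299 x 299 input (illustrative reference, derived from the layer
# hyper-parameters above):
#   stem convs              -> 64 x 147 x 147
#   Mixed_3a                -> 160 x 73 x 73
#   Mixed_4a                -> 192 x 71 x 71
#   Mixed_5a                -> 384 x 35 x 35
#   4 x Inception_A         -> 384 x 35 x 35
#   Reduction_A (Mixed_6a)  -> 1024 x 17 x 17
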
class Inception_B(nn.Module):

    def __init__(self):
        super(Inception_B, self).__init__()
        self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(
                192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(
                224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)))

        self.branch2 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(
                192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(
                192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(
                224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(
                224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)))

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(1024, 128, kernel_size=1, stride=1))

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class Reduction_B(nn.Module):

    def __init__(self):
        super(Reduction_B, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=3, stride=2))

        self.branch1 = nn.Sequential(
            BasicConv2d(1024, 256, kernel_size=1, stride=1),
            BasicConv2d(
                256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(
                256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(320, 320, kernel_size=3, stride=2))

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out


class Inception_C(nn.Module):

    def __init__(self):
        super(Inception_C, self).__init__()

        self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)

        self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
        self.branch1_1a = BasicConv2d(
            384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch1_1b = BasicConv2d(
            384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
        self.branch2_1 = BasicConv2d(
            384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.branch2_2 = BasicConv2d(
            448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3a = BasicConv2d(
            512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3b = BasicConv2d(
            512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(1536, 256, kernel_size=1, stride=1))

    def forward(self, x):
        x0 = self.branch0(x)

        x1_0 = self.branch1_0(x)
        x1_1a = self.branch1_1a(x1_0)
        x1_1b = self.branch1_1b(x1_0)
        x1 = torch.cat((x1_1a, x1_1b), 1)

        x2_0 = self.branch2_0(x)
        x2_1 = self.branch2_1(x2_0)
        x2_2 = self.branch2_2(x2_1)
        x2_3a = self.branch2_3a(x2_2)
        x2_3b = self.branch2_3b(x2_2)
        x2 = torch.cat((x2_3a, x2_3b), 1)

        x3 = self.branch3(x)

        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class InceptionAux(nn.Module):

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.conv1 = BasicConv2d(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        # N x in_channels x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x in_channels x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x num_classes
        return x
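
# Feature map sizes through the Inception-B/C stages for the same
# 3 x 299 x 299 input (illustrative reference):
#   7 x Inception_B (Mixed_6h)  -> 1024 x 17 x 17
#   Reduction_B (Mixed_7a)      -> 1536 x 8 x 8
#   3 x Inception_C             -> 1536 x 8 x 8
# In Inception4.forward below, InceptionAux consumes the 1024 x 17 x 17 map
# produced after the last Inception_B block (self.features[:-4]).
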
# class BasicConv2d(nn.Module):
#     def __init__(self, in_channels, out_channels, **kwargs):
#         super(BasicConv2d, self).__init__()
#         self.conv = nn.Conv2d(
#             in_channels, out_channels, bias=False, **kwargs)
#         self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
#
#     def forward(self, x):
#         x = self.conv(x)
#         x = self.bn(x)
#         return F.relu(x, inplace=True)


@BACKBONES.register_module
class Inception4(nn.Module):
    """InceptionV4 backbone.

    Args:
        num_classes (int): Number of classes for the classification head.
            If ``num_classes > 0``, an extra fc layer (``last_linear``) is
            appended and the forward pass returns logits; if it is 0
            (default), the backbone returns feature maps.
        p_dropout (float): Dropout probability. Default: 0.2.
        aux_logits (bool): Whether to build the auxiliary classifier head.
            Default: True.
    """

    def __init__(self,
                 num_classes: int = 0,
                 p_dropout=0.2,
                 aux_logits: bool = True):
        super(Inception4, self).__init__()
        self.aux_logits = aux_logits
        # Modules
        self.features = nn.Sequential(
            BasicConv2d(3, 32, kernel_size=3, stride=2),
            BasicConv2d(32, 32, kernel_size=3, stride=1),
            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed_3a(),
            Mixed_4a(),
            Mixed_5a(),
            Inception_A(),
            Inception_A(),
            Inception_A(),
            Inception_A(),
            Reduction_A(),  # Mixed_6a
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Inception_B(),
            Inception_B(),  # Mixed_6h 1024 x 17 x 17
            Reduction_B(),  # Mixed_7a
            Inception_C(),
            Inception_C(),
            Inception_C())
        if aux_logits:
            self.AuxLogits = InceptionAux(1024, num_classes)
        self.dropout = nn.Dropout(p_dropout)
        self.last_linear = None
        if num_classes > 0:
            self.last_linear = nn.Linear(1536, num_classes)

        self.default_pretrained_model_path = model_urls[
            self.__class__.__name__]

    @property
    def fc(self):
        return self.last_linear

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)

    def logits(self, features):
        x = F.adaptive_avg_pool2d(features, output_size=(1, 1))
        # x = F.avg_pool2d(features, kernel_size=adaptiveAvgPoolWidth)
        x = x.view(x.size(0), -1)  # B x 1536
        x = self.fc(x)  # B x num_classes
        return x

    def forward(self, input: torch.Tensor):
        """Forward the InceptionV4 backbone.

        Args:
            input (torch.Tensor): An RGB image tensor with shape B x C x H x W.

        Returns:
            list: ``[aux, x]``, where ``aux`` is the auxiliary logits (only
            computed in training mode when ``aux_logits=True``, otherwise
            ``None``) and ``x`` is the logit tensor when ``num_classes > 0``
            or the feature tensor when ``num_classes`` is 0 (default).
        """
        if self.training and self.aux_logits:
            x = self.features[:-4](input)
            aux = self.AuxLogits(x)
            x = self.features[-4:](x)
        else:
            x = self.features(input)
            aux = None
        if self.fc is not None:
            x = self.logits(x)
        return [aux, x]
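
# Usage sketch (illustrative only; assumes this module is imported from its
# parent package so that the relative imports above resolve):
#
#     model = Inception4(num_classes=1000)   # backbone + aux/classification heads
#     model.init_weights()
#     model.train()
#     imgs = torch.randn(2, 3, 299, 299)     # B x C x H x W
#     aux, logits = model(imgs)              # aux: 2 x 1000, logits: 2 x 1000
#
#     backbone = Inception4(num_classes=0, aux_logits=False)
#     backbone.eval()
#     with torch.no_grad():
#         _, feats = backbone(imgs)          # feats: 2 x 1536 x 8 x 8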