from __future__ import division, absolute_import

import torch
from torch import nn
from torch.nn import functional as F

__all__ = ['osnet_avgpool', 'osnet_maxpool']


##########
# Basic layers
##########
class ConvLayer(nn.Module):
    """Convolution layer: bias-free Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero padding passed to the convolution.
        groups: number of convolution groups.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        groups=1
    ):
        super(ConvLayer, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            groups=groups
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU in that order."""
        return self.relu(self.bn(self.conv(x)))


class Conv1x1(nn.Module):
    """1x1 convolution followed by BatchNorm2d and ReLU."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super(Conv1x1, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            1,
            stride=stride,
            padding=0,
            bias=False,
            groups=groups
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU in that order."""
        return self.relu(self.bn(self.conv(x)))


class Conv1x1Linear(nn.Module):
    """1x1 convolution without non-linearity (Conv2d -> BatchNorm2d)."""

    def __init__(self, in_channels, out_channels, stride=1):
        super(Conv1x1Linear, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, 1, stride=stride, padding=0, bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply convolution and batch norm; no activation follows."""
        return self.bn(self.conv(x))


class Conv3x3(nn.Module):
    """3x3 convolution (padding 1) followed by BatchNorm2d and ReLU."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super(Conv3x3, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            3,
            stride=stride,
            padding=1,
            bias=False,
            groups=groups
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU in that order."""
        return self.relu(self.bn(self.conv(x)))


class LightConv3x3(nn.Module):
    """Lightweight 3x3 convolution.

    1x1 (linear) + dw 3x3 (nonlinear).
    """

    def __init__(self, in_channels, out_channels):
        super(LightConv3x3, self).__init__()
        # Pointwise projection; no norm/activation between the two convs.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 1, stride=1, padding=0, bias=False
        )
        # Depthwise 3x3: one group per channel.
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            3,
            stride=1,
            padding=1,
            bias=False,
            groups=out_channels
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply 1x1 conv, depthwise 3x3 conv, batch norm, then ReLU."""
        x = self.conv1(x)
        x = self.conv2(x)
        return self.relu(self.bn(x))


##########
# Building blocks for omni-scale feature learning
##########
class ChannelGate(nn.Module):
    """A mini-network that generates channel-wise gates conditioned on input.

    The input is globally average-pooled, squeezed to
    ``in_channels // reduction`` channels by a 1x1 conv, optionally
    layer-normalised, passed through ReLU and expanded to ``num_gates``
    channels by a second 1x1 conv.

    Args:
        in_channels: number of channels of the input feature map.
        num_gates: number of gate channels; defaults to ``in_channels``.
        return_gates: if True, ``forward`` returns the gates themselves
            instead of the gated input.
        gate_activation: ``'sigmoid'``, ``'relu'`` or ``'linear'`` (no
            activation).
        reduction: channel reduction factor of the bottleneck.
        layer_norm: if True, apply LayerNorm after the first 1x1 conv.

    Raises:
        RuntimeError: if ``gate_activation`` is not one of the known names.
    """

    def __init__(
        self,
        in_channels,
        num_gates=None,
        return_gates=False,
        gate_activation='sigmoid',
        reduction=16,
        layer_norm=False
    ):
        super(ChannelGate, self).__init__()
        if num_gates is None:
            num_gates = in_channels
        self.return_gates = return_gates
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(
            in_channels,
            in_channels // reduction,
            kernel_size=1,
            bias=True,
            padding=0
        )
        self.norm1 = None
        if layer_norm:
            self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(
            in_channels // reduction,
            num_gates,
            kernel_size=1,
            bias=True,
            padding=0
        )
        if gate_activation == 'sigmoid':
            self.gate_activation = nn.Sigmoid()
        elif gate_activation == 'relu':
            self.gate_activation = nn.ReLU(inplace=True)
        elif gate_activation == 'linear':
            self.gate_activation = None
        else:
            raise RuntimeError(
                "Unknown gate activation: {}".format(gate_activation)
            )

    def forward(self, x):
        """Return the gated input, or the raw gates if ``return_gates``."""
        # Keep the un-pooled input so the gates can be multiplied back in.
        identity = x
        x = self.global_avgpool(x)
        x = self.fc1(x)
        if self.norm1 is not None:
            x = self.norm1(x)
        x = self.relu(x)
        x = self.fc2(x)
        if self.gate_activation is not None:
            x = self.gate_activation(x)
        if self.return_gates:
            return x
        return identity * x


class OSBlock(nn.Module):
    """Omni-scale feature learning block.

    A 1x1 conv reduces the input to ``out_channels // 4`` channels, which
    feed four parallel streams made of 1, 2, 3 and 4 stacked
    ``LightConv3x3`` layers. Each stream passes through the same
    ``ChannelGate`` instance and the gated results are summed, expanded to
    ``out_channels`` by a linear 1x1 conv, added to the (optionally
    projected) input and passed through ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        **kwargs: accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(OSBlock, self).__init__()
        mid_channels = out_channels // 4
        self.conv1 = Conv1x1(in_channels, mid_channels)
        self.conv2a = LightConv3x3(mid_channels, mid_channels)
        self.conv2b = nn.Sequential(
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
        )
        self.conv2c = nn.Sequential(
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
        )
        self.conv2d = nn.Sequential(
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
        )
        # A single gate is shared by all four streams.
        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels)
        # Project the shortcut only when the channel counts differ.
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = Conv1x1Linear(in_channels, out_channels)

    def forward(self, x):
        """Run the four gated streams and add the residual shortcut."""
        residual = x
        x1 = self.conv1(x)
        x2a = self.conv2a(x1)
        x2b = self.conv2b(x1)
        x2c = self.conv2c(x1)
        x2d = self.conv2d(x1)
        x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
        x3 = self.conv3(x2)
        if self.downsample is not None:
            residual = self.downsample(residual)
        out = x3 + residual
        return F.relu(out)


##########
# Network architecture
##########
class BaseNet(nn.Module):
    """Shared construction and initialisation helpers for the networks."""

    def _make_layer(
        self, block, layer, in_channels, out_channels, reduce_spatial_size
    ):
        """Build one stage of ``layer`` stacked blocks.

        Args:
            block: callable ``block(in_channels, out_channels)`` returning
                a module.
            layer: number of blocks in the stage.
            in_channels: channels entering the first block.
            out_channels: channels produced by every block.
            reduce_spatial_size: if True, append a ``Conv1x1`` followed by
                a stride-2 2x2 average pool.

        Returns:
            nn.Sequential holding the stage.
        """
        layers = [block(in_channels, out_channels)]
        layers.extend(
            block(out_channels, out_channels) for _ in range(1, layer)
        )

        if reduce_spatial_size:
            layers.append(
                nn.Sequential(
                    Conv1x1(out_channels, out_channels),
                    nn.AvgPool2d(2, stride=2)
                )
            )

        return nn.Sequential(*layers)

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Build the fully connected head and record ``self.feature_dim``.

        Args:
            fc_dims: ``None``, a single dimension, or a list/tuple of
                dimensions. ``None``, a negative scalar or an empty
                sequence mean "no FC head".
            input_dim: dimension of the features entering the head.
            dropout_p: if not ``None``, dropout probability appended after
                each Linear -> BatchNorm1d -> ReLU group.

        Returns:
            nn.Sequential, or ``None`` when no head is built (in which
            case ``self.feature_dim`` is ``input_dim``).
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        if isinstance(fc_dims, (list, tuple)):
            fc_dims = list(fc_dims)
        else:
            # Only a scalar can be compared with 0: ordering a list against
            # an int raises TypeError on Python 3, so the sign test is
            # limited to this branch.
            if fc_dims < 0:
                self.feature_dim = input_dim
                return None
            fc_dims = [fc_dims]

        if not fc_dims:
            # Nothing to build; avoids indexing an empty list below.
            self.feature_dim = input_dim
            return None

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def init_params(self):
        """Initialise conv, batch-norm and linear parameters in place.

        Conv2d weights use Kaiming-normal (fan_out, relu); batch-norm
        weights are set to 1; Linear weights are drawn from N(0, 0.01).
        Every bias that exists is set to 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu'
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


class OSNet(BaseNet):
    """Network built from a 7x7 stem, three block stages and an FC head.

    Args:
        num_classes: output size of the classifier.
        blocks: block constructors, one per stage; indices 0-2 are used.
        layers: number of blocks in each stage.
        channels: channel widths; ``channels[0]`` is the stem width and
            ``channels[1:4]`` are the stage widths.
        feature_dim: dimension(s) of the FC head, forwarded to
            ``_construct_fc_layer`` (an int, a list/tuple of ints, or
            ``None`` / a negative int to disable the head).
        loss: stored on the instance as ``self.loss``; ``forward`` in this
            file does not consult it.
        pool: ``'avg'`` selects global average pooling; any other value
            selects global max pooling.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        num_classes,
        blocks,
        layers,
        channels,
        feature_dim=512,
        loss='softmax',
        pool='avg',
        **kwargs
    ):
        super(OSNet, self).__init__()
        num_blocks = len(blocks)
        assert num_blocks == len(layers)
        assert num_blocks == len(channels) - 1
        self.loss = loss

        # convolutional backbone
        self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.conv2 = self._make_layer(
            blocks[0],
            layers[0],
            channels[0],
            channels[1],
            reduce_spatial_size=True
        )
        self.conv3 = self._make_layer(
            blocks[1],
            layers[1],
            channels[1],
            channels[2],
            reduce_spatial_size=True
        )
        self.conv4 = self._make_layer(
            blocks[2],
            layers[2],
            channels[2],
            channels[3],
            reduce_spatial_size=False
        )
        self.conv5 = Conv1x1(channels[3], channels[3])
        if pool == 'avg':
            self.global_pool = nn.AdaptiveAvgPool2d(1)
        else:
            self.global_pool = nn.AdaptiveMaxPool2d(1)

        # fully connected layer
        self.fc = self._construct_fc_layer(
            feature_dim, channels[3], dropout_p=None
        )

        # classification layer
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self.init_params()

    def featuremaps(self, x):
        """Run the convolutional backbone and return its feature map."""
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x

    def forward(self, x):
        """Return classifier outputs for ``x``.

        In training mode the raw classifier outputs are returned; in
        evaluation mode they are passed through a sigmoid first.
        """
        x = self.featuremaps(x)
        v = self.global_pool(x)
        # Flatten everything but the leading (batch) dimension.
        v = v.view(v.size(0), -1)
        if self.fc is not None:
            v = self.fc(v)
        y = self.classifier(v)
        if not self.training:
            y = torch.sigmoid(y)
        return y


def osnet_avgpool(num_classes=1000, loss='softmax', **kwargs):
    """Build an OSNet with global average pooling.

    Extra keyword arguments (e.g. ``feature_dim``) are forwarded to
    ``OSNet``.
    """
    return OSNet(
        num_classes,
        blocks=[OSBlock, OSBlock, OSBlock],
        layers=[2, 2, 2],
        channels=[64, 256, 384, 512],
        loss=loss,
        pool='avg',
        **kwargs
    )


def osnet_maxpool(num_classes=1000, loss='softmax', **kwargs):
    """Build an OSNet with global max pooling.

    Extra keyword arguments (e.g. ``feature_dim``) are forwarded to
    ``OSNet``.
    """
    return OSNet(
        num_classes,
        blocks=[OSBlock, OSBlock, OSBlock],
        layers=[2, 2, 2],
        channels=[64, 256, 384, 512],
        loss=loss,
        pool='max',
        **kwargs
    )