from __future__ import division, absolute_import from torch import nn from torch.nn import functional as F ########## # Basic layers ########## class ConvLayer(nn.Module): """Convolution layer (conv + bn + relu).""" def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, IN=False ): super(ConvLayer, self).__init__() self.conv = nn.Conv2d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False, groups=groups ) if IN: self.bn = nn.InstanceNorm2d(out_channels, affine=True) else: self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) def forward(self, x): x = self.conv(x) x = self.bn(x) return self.relu(x) class Conv1x1(nn.Module): """1x1 convolution + bn + relu.""" def __init__(self, in_channels, out_channels, stride=1, groups=1): super(Conv1x1, self).__init__() self.conv = nn.Conv2d( in_channels, out_channels, 1, stride=stride, padding=0, bias=False, groups=groups ) self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) def forward(self, x): x = self.conv(x) x = self.bn(x) return self.relu(x) class Conv1x1Linear(nn.Module): """1x1 convolution + bn (w/o non-linearity).""" def __init__(self, in_channels, out_channels, stride=1, bn=True): super(Conv1x1Linear, self).__init__() self.conv = nn.Conv2d( in_channels, out_channels, 1, stride=stride, padding=0, bias=False ) self.bn = None if bn: self.bn = nn.BatchNorm2d(out_channels) def forward(self, x): x = self.conv(x) if self.bn is not None: x = self.bn(x) return x class Conv3x3(nn.Module): """3x3 convolution + bn + relu.""" def __init__(self, in_channels, out_channels, stride=1, groups=1): super(Conv3x3, self).__init__() self.conv = nn.Conv2d( in_channels, out_channels, 3, stride=stride, padding=1, bias=False, groups=groups ) self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) def forward(self, x): x = self.conv(x) x = self.bn(x) return self.relu(x) class LightConv3x3(nn.Module): """Lightweight 3x3 
convolution. 1x1 (linear) + dw 3x3 (nonlinear). """ def __init__(self, in_channels, out_channels): super(LightConv3x3, self).__init__() self.conv1 = nn.Conv2d( in_channels, out_channels, 1, stride=1, padding=0, bias=False ) self.conv2 = nn.Conv2d( out_channels, out_channels, 3, stride=1, padding=1, bias=False, groups=out_channels ) self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = self.bn(x) return self.relu(x) class LightConvStream(nn.Module): """Lightweight convolution stream.""" def __init__(self, in_channels, out_channels, depth): super(LightConvStream, self).__init__() assert depth >= 1, 'depth must be equal to or larger than 1, but got {}'.format( depth ) layers = [] layers += [LightConv3x3(in_channels, out_channels)] for i in range(depth - 1): layers += [LightConv3x3(out_channels, out_channels)] self.layers = nn.Sequential(*layers) def forward(self, x): return self.layers(x) ########## # Building blocks for omni-scale feature learning ########## class ChannelGate(nn.Module): """A mini-network that generates channel-wise gates conditioned on input tensor.""" def __init__( self, in_channels, num_gates=None, return_gates=False, gate_activation='sigmoid', reduction=16, layer_norm=False ): super(ChannelGate, self).__init__() if num_gates is None: num_gates = in_channels self.return_gates = return_gates self.global_avgpool = nn.AdaptiveAvgPool2d(1) self.fc1 = nn.Conv2d( in_channels, in_channels // reduction, kernel_size=1, bias=True, padding=0 ) self.norm1 = None if layer_norm: self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1)) self.relu = nn.ReLU(inplace=True) self.fc2 = nn.Conv2d( in_channels // reduction, num_gates, kernel_size=1, bias=True, padding=0 ) if gate_activation == 'sigmoid': self.gate_activation = nn.Sigmoid() elif gate_activation == 'relu': self.gate_activation = nn.ReLU(inplace=True) elif gate_activation == 'linear': self.gate_activation = None else: raise 
RuntimeError( "Unknown gate activation: {}".format(gate_activation) ) def forward(self, x): input = x x = self.global_avgpool(x) x = self.fc1(x) if self.norm1 is not None: x = self.norm1(x) x = self.relu(x) x = self.fc2(x) if self.gate_activation is not None: x = self.gate_activation(x) if self.return_gates: return x return input * x class OSBlock(nn.Module): """Omni-scale feature learning block.""" def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs): super(OSBlock, self).__init__() assert T >= 1 assert out_channels >= reduction and out_channels % reduction == 0 mid_channels = out_channels // reduction self.conv1 = Conv1x1(in_channels, mid_channels) self.conv2 = nn.ModuleList() for t in range(1, T + 1): self.conv2 += [LightConvStream(mid_channels, mid_channels, t)] self.gate = ChannelGate(mid_channels) self.conv3 = Conv1x1Linear(mid_channels, out_channels) self.downsample = None if in_channels != out_channels: self.downsample = Conv1x1Linear(in_channels, out_channels) def forward(self, x): identity = x x1 = self.conv1(x) x2 = 0 for conv2_t in self.conv2: x2_t = conv2_t(x1) x2 = x2 + self.gate(x2_t) x3 = self.conv3(x2) if self.downsample is not None: identity = self.downsample(identity) out = x3 + identity return F.relu(out) class OSBlockINv1(nn.Module): """Omni-scale feature learning block with instance normalization.""" def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs): super(OSBlockINv1, self).__init__() assert T >= 1 assert out_channels >= reduction and out_channels % reduction == 0 mid_channels = out_channels // reduction self.conv1 = Conv1x1(in_channels, mid_channels) self.conv2 = nn.ModuleList() for t in range(1, T + 1): self.conv2 += [LightConvStream(mid_channels, mid_channels, t)] self.gate = ChannelGate(mid_channels) self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False) self.downsample = None if in_channels != out_channels: self.downsample = Conv1x1Linear(in_channels, out_channels) self.IN = 
nn.InstanceNorm2d(out_channels, affine=True) def forward(self, x): identity = x x1 = self.conv1(x) x2 = 0 for conv2_t in self.conv2: x2_t = conv2_t(x1) x2 = x2 + self.gate(x2_t) x3 = self.conv3(x2) x3 = self.IN(x3) # IN inside residual if self.downsample is not None: identity = self.downsample(identity) out = x3 + identity return F.relu(out) class OSBlockINv2(nn.Module): """Omni-scale feature learning block with instance normalization.""" def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs): super(OSBlockINv2, self).__init__() assert T >= 1 assert out_channels >= reduction and out_channels % reduction == 0 mid_channels = out_channels // reduction self.conv1 = Conv1x1(in_channels, mid_channels) self.conv2 = nn.ModuleList() for t in range(1, T + 1): self.conv2 += [LightConvStream(mid_channels, mid_channels, t)] self.gate = ChannelGate(mid_channels) self.conv3 = Conv1x1Linear(mid_channels, out_channels) self.downsample = None if in_channels != out_channels: self.downsample = Conv1x1Linear(in_channels, out_channels) self.IN = nn.InstanceNorm2d(out_channels, affine=True) def forward(self, x): identity = x x1 = self.conv1(x) x2 = 0 for conv2_t in self.conv2: x2_t = conv2_t(x1) x2 = x2 + self.gate(x2_t) x3 = self.conv3(x2) if self.downsample is not None: identity = self.downsample(identity) out = x3 + identity out = self.IN(out) # IN outside residual return F.relu(out) class OSBlockINv3(nn.Module): """Omni-scale feature learning block with instance normalization.""" def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs): super(OSBlockINv3, self).__init__() assert T >= 1 assert out_channels >= reduction and out_channels % reduction == 0 mid_channels = out_channels // reduction self.conv1 = Conv1x1(in_channels, mid_channels) self.conv2 = nn.ModuleList() for t in range(1, T + 1): self.conv2 += [LightConvStream(mid_channels, mid_channels, t)] self.gate = ChannelGate(mid_channels) self.conv3 = Conv1x1Linear(mid_channels, out_channels, 
bn=False) self.downsample = None if in_channels != out_channels: self.downsample = Conv1x1Linear(in_channels, out_channels) self.IN_in = nn.InstanceNorm2d(out_channels, affine=True) self.IN_out = nn.InstanceNorm2d(out_channels, affine=True) def forward(self, x): identity = x x1 = self.conv1(x) x2 = 0 for conv2_t in self.conv2: x2_t = conv2_t(x1) x2 = x2 + self.gate(x2_t) x3 = self.conv3(x2) x3 = self.IN_in(x3) # IN inside residual if self.downsample is not None: identity = self.downsample(identity) out = x3 + identity out = self.IN_out(out) # IN outside residual return F.relu(out) ########## # Network architecture ########## class OSNet(nn.Module): """Omni-Scale Network. Reference: - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019. - Zhou et al. Learning Generalisable Omni-Scale Representations for Person Re-Identification. TPAMI, 2021. """ def __init__( self, num_classes, blocks, layers, channels, feature_dim=512, loss='softmax', conv1_IN=True, **kwargs ): super(OSNet, self).__init__() num_blocks = len(blocks) assert num_blocks == len(layers) assert num_blocks == len(channels) - 1 self.loss = loss self.feature_dim = feature_dim # convolutional backbone self.conv1 = ConvLayer( 3, channels[0], 7, stride=2, padding=3, IN=conv1_IN ) self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) self.conv2 = self._make_layer( blocks[0], layers[0], channels[0], channels[1] ) self.pool2 = nn.Sequential( Conv1x1(channels[1], channels[1]), nn.AvgPool2d(2, stride=2) ) self.conv3 = self._make_layer( blocks[1], layers[1], channels[1], channels[2] ) self.pool3 = nn.Sequential( Conv1x1(channels[2], channels[2]), nn.AvgPool2d(2, stride=2) ) self.conv4 = self._make_layer( blocks[2], layers[2], channels[2], channels[3] ) self.conv5 = Conv1x1(channels[3], channels[3]) self.global_avgpool = nn.AdaptiveAvgPool2d(1) # fully connected layer self.fc = self._construct_fc_layer( self.feature_dim, channels[3], dropout_p=None ) # identity classification layer 
self.classifier = nn.Linear(self.feature_dim, num_classes) self._init_params() def _make_layer(self, blocks, layer, in_channels, out_channels): layers = [] layers += [blocks[0](in_channels, out_channels)] for i in range(1, len(blocks)): layers += [blocks[i](out_channels, out_channels)] return nn.Sequential(*layers) def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): if fc_dims is None or fc_dims < 0: self.feature_dim = input_dim return None if isinstance(fc_dims, int): fc_dims = [fc_dims] layers = [] for dim in fc_dims: layers.append(nn.Linear(input_dim, dim)) layers.append(nn.BatchNorm1d(dim)) layers.append(nn.ReLU(inplace=True)) if dropout_p is not None: layers.append(nn.Dropout(p=dropout_p)) input_dim = dim self.feature_dim = fc_dims[-1] return nn.Sequential(*layers) def _init_params(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_( m.weight, mode='fan_out', nonlinearity='relu' ) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm1d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.InstanceNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) if m.bias is not None: nn.init.constant_(m.bias, 0) def featuremaps(self, x): x = self.conv1(x) x = self.maxpool(x) x = self.conv2(x) x = self.pool2(x) x = self.conv3(x) x = self.pool3(x) x = self.conv4(x) return self.conv5(x) def forward(self, x, return_featuremaps=False, **kwargs): x = self.featuremaps(x) if return_featuremaps: return x v = self.global_avgpool(x) v = v.view(v.size(0), -1) if self.fc is not None: v = self.fc(v) if not self.training: return v y = self.classifier(v) if self.loss == 'softmax': return y elif self.loss == 'triplet': return y, v else: raise KeyError("Unsupported loss: {}".format(self.loss)) ########## # 
##########
# Instantiation
##########


def osnet_ain_x1_0(
    num_classes=1000, pretrained=True, loss='softmax', **kwargs
):
    """Build OSNet-AIN with width 1.0 (channels 64/256/384/512).

    Args:
        num_classes (int): output size of the identity classifier.
        pretrained (bool): accepted for API compatibility only. This function
            never reads it and does not load any weights.
        loss (str): ``'softmax'`` or ``'triplet'``, passed to ``OSNet``.
        **kwargs: forwarded to ``OSNet`` (e.g. ``feature_dim``).

    Returns:
        OSNet: the constructed model.
    """
    # NOTE(review): ``pretrained`` defaults to True but is ignored here;
    # callers expecting loaded weights must load a checkpoint themselves.
    model = OSNet(
        num_classes,
        blocks=[
            [OSBlockINv1, OSBlockINv1],
            [OSBlock, OSBlockINv1],
            [OSBlockINv1, OSBlock],
        ],
        layers=[2, 2, 2],
        channels=[64, 256, 384, 512],
        loss=loss,
        conv1_IN=True,
        **kwargs
    )
    return model


# Registry of model factories available through ``build_model``.
__models = {'osnet_ain_x1_0': osnet_ain_x1_0}


def build_model(name, num_classes=100, **kwargs):
    """Instantiate a registered model by name.

    Args:
        name (str): key in the model registry (e.g. ``'osnet_ain_x1_0'``).
        num_classes (int): output size of the identity classifier.
        **kwargs: extra keyword arguments forwarded to the model factory
            (e.g. ``loss='triplet'``, ``feature_dim=256``).

    Returns:
        The model returned by the selected factory.

    Raises:
        KeyError: if ``name`` is not registered.
    """
    avai_models = list(__models.keys())
    if name not in __models:
        raise KeyError(
            'Unknown model: {}. Must be one of {}'.format(name, avai_models)
        )
    return __models[name](num_classes=num_classes, **kwargs)