from __future__ import division, absolute_import

import torch
from torch import nn
from torch.nn import functional as F

EPS = 1e-12
NORM_AFFINE = False  # enable affine transformations for the normalization layers


##########
# Basic layers
##########
class IBN(nn.Module):
    """Instance + Batch Normalization.

    The first half of the channels is instance-normalized, the second half is
    batch-normalized, and the two halves are concatenated back together.
    num_channels is assumed to be even so that torch.split yields exactly two
    halves.
    """

    def __init__(self, num_channels):
        super(IBN, self).__init__()
        half1 = num_channels // 2
        self.half = half1
        half2 = num_channels - half1
        self.IN = nn.InstanceNorm2d(half1, affine=NORM_AFFINE)
        self.BN = nn.BatchNorm2d(half2, affine=NORM_AFFINE)

    def forward(self, x):
        split = torch.split(x, self.half, 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        return torch.cat((out1, out2), 1)


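# A minimal usage sketch for IBN (illustrative only): the channel count and
# input sizes below are assumptions, chosen so the two halves split evenly.
def _ibn_usage_sketch():
    ibn = IBN(64)  # first 32 channels -> IN, remaining 32 -> BN
    x = torch.randn(2, 64, 32, 32)  # NCHW dummy batch
    y = ibn(x)
    assert y.shape == x.shape  # IBN preserves the input shape
    return y

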
class ConvLayer(nn.Module):
    """Convolution layer (conv + bn + relu)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        groups=1,
        IN=False
    ):
        super(ConvLayer, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            groups=groups
        )
        if IN:
            self.bn = nn.InstanceNorm2d(out_channels, affine=NORM_AFFINE)
        else:
            self.bn = nn.BatchNorm2d(out_channels, affine=NORM_AFFINE)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)


class Conv1x1(nn.Module):
    """1x1 convolution + bn + relu."""

    def __init__(
        self, in_channels, out_channels, stride=1, groups=1, ibn=False
    ):
        super(Conv1x1, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            1,
            stride=stride,
            padding=0,
            bias=False,
            groups=groups
        )
        if ibn:
            self.bn = IBN(out_channels)
        else:
            self.bn = nn.BatchNorm2d(out_channels, affine=NORM_AFFINE)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)


class Conv1x1Linear(nn.Module):
    """1x1 convolution + bn (without non-linearity)."""

    def __init__(self, in_channels, out_channels, stride=1, bn=True):
        super(Conv1x1Linear, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, 1, stride=stride, padding=0, bias=False
        )
        self.bn = None
        if bn:
            self.bn = nn.BatchNorm2d(out_channels, affine=NORM_AFFINE)

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        return x


class Conv3x3(nn.Module):
    """3x3 convolution + bn + relu."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super(Conv3x3, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            3,
            stride=stride,
            padding=1,
            bias=False,
            groups=groups
        )
        self.bn = nn.BatchNorm2d(out_channels, affine=NORM_AFFINE)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)


class LightConv3x3(nn.Module):
    """Lightweight 3x3 convolution.

    A 1x1 (linear) pointwise convolution followed by a depthwise 3x3
    (non-linear) convolution.
    """

    def __init__(self, in_channels, out_channels):
        super(LightConv3x3, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 1, stride=1, padding=0, bias=False
        )
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            3,
            stride=1,
            padding=1,
            bias=False,
            groups=out_channels
        )
        self.bn = nn.BatchNorm2d(out_channels, affine=NORM_AFFINE)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn(x)
        return self.relu(x)


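# A minimal sketch (illustrative assumptions: 64 -> 128 channels) contrasting
# the pointwise + depthwise factorization of LightConv3x3 with a dense 3x3
# convolution; with NORM_AFFINE = False the BatchNorm adds no parameters.
def _light_conv_param_sketch():
    dense = nn.Conv2d(64, 128, 3, padding=1, bias=False)  # 64*128*9 = 73728 weights
    light = LightConv3x3(64, 128)  # 64*128 (1x1) + 128*9 (dw 3x3) = 9344 weights
    n_dense = sum(p.numel() for p in dense.parameters())
    n_light = sum(p.numel() for p in light.parameters())
    return n_dense, n_light

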
class LightConvStream(nn.Module):
    """Lightweight convolution stream."""

    def __init__(self, in_channels, out_channels, depth):
        super(LightConvStream, self).__init__()
        assert depth >= 1, 'depth must be equal to or larger than 1, but got {}'.format(
            depth
        )
        layers = []
        layers += [LightConv3x3(in_channels, out_channels)]
        for _ in range(depth - 1):
            layers += [LightConv3x3(out_channels, out_channels)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


##########
# Building blocks for omni-scale feature learning
##########
class ChannelGate(nn.Module):
    """A mini-network that generates channel-wise gates conditioned on the input tensor."""

    def __init__(
        self,
        in_channels,
        num_gates=None,
        return_gates=False,
        gate_activation='sigmoid',
        reduction=16,
        layer_norm=False
    ):
        super(ChannelGate, self).__init__()
        if num_gates is None:
            num_gates = in_channels
        self.return_gates = return_gates
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(
            in_channels,
            in_channels // reduction,
            kernel_size=1,
            bias=True,
            padding=0
        )
        self.norm1 = None
        if layer_norm:
            self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(
            in_channels // reduction,
            num_gates,
            kernel_size=1,
            bias=True,
            padding=0
        )
        if gate_activation == 'sigmoid':
            self.gate_activation = nn.Sigmoid()
        elif gate_activation == 'relu':
            self.gate_activation = nn.ReLU(inplace=True)
        elif gate_activation == 'linear':
            self.gate_activation = None
        else:
            raise RuntimeError(
                'Unknown gate activation: {}'.format(gate_activation)
            )

    def forward(self, x):
        inp = x
        x = self.global_avgpool(x)
        x = self.fc1(x)
        if self.norm1 is not None:
            x = self.norm1(x)
        x = self.relu(x)
        x = self.fc2(x)
        if self.gate_activation is not None:
            x = self.gate_activation(x)
        if self.return_gates:
            return x
        return inp * x


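# A minimal sketch of ChannelGate (sizes are illustrative assumptions): it
# squeezes the feature map to one sigmoid gate per channel, shape (N, C, 1, 1),
# which forward() broadcast-multiplies over the input when return_gates=False.
def _channel_gate_sketch():
    gate = ChannelGate(64, return_gates=True)
    x = torch.randn(2, 64, 16, 16)
    g = gate(x)
    assert g.shape == (2, 64, 1, 1)  # one gate per channel
    return x * g  # the same broadcast performed inside forward()

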
class OSBlock(nn.Module):
    """Omni-scale feature learning block."""

    def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
        super(OSBlock, self).__init__()
        assert T >= 1
        assert out_channels >= reduction and out_channels % reduction == 0
        mid_channels = out_channels // reduction

        self.conv1 = Conv1x1(in_channels, mid_channels)
        self.conv2 = nn.ModuleList()
        for t in range(1, T + 1):
            self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels)
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = Conv1x1Linear(in_channels, out_channels)

    def forward(self, x):
        identity = x
        x1 = self.conv1(x)
        x2 = 0
        for conv2_t in self.conv2:
            x2_t = conv2_t(x1)
            x2 = x2 + self.gate(x2_t)  # gated sum over the T scale streams
        x3 = self.conv3(x2)
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = x3 + identity
        return F.relu(out)


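# A minimal sketch of OSBlock (sizes are illustrative assumptions): the T
# gated streams keep the spatial size, and a 1x1 downsample projects the
# identity when the channel count changes.
def _os_block_sketch():
    block = OSBlock(64, 256)  # in != out -> downsample path is active
    x = torch.randn(2, 64, 16, 16)
    y = block(x)
    assert y.shape == (2, 256, 16, 16)  # channels change, spatial size does not
    return y

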
class OSBlockINv1(nn.Module):
    """Omni-scale feature learning block with instance normalization inside the residual."""

    def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
        super(OSBlockINv1, self).__init__()
        assert T >= 1
        assert out_channels >= reduction and out_channels % reduction == 0
        mid_channels = out_channels // reduction

        self.conv1 = Conv1x1(in_channels, mid_channels)
        self.conv2 = nn.ModuleList()
        for t in range(1, T + 1):
            self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = Conv1x1Linear(in_channels, out_channels)
        self.IN = nn.InstanceNorm2d(out_channels, affine=NORM_AFFINE)

    def forward(self, x):
        identity = x
        x1 = self.conv1(x)
        x2 = 0
        for conv2_t in self.conv2:
            x2_t = conv2_t(x1)
            x2 = x2 + self.gate(x2_t)
        x3 = self.conv3(x2)
        x3 = self.IN(x3)  # IN inside residual
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = x3 + identity
        return F.relu(out)


class OSBlockINv2(nn.Module):
    """Omni-scale feature learning block with instance normalization outside the residual."""

    def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
        super(OSBlockINv2, self).__init__()
        assert T >= 1
        assert out_channels >= reduction and out_channels % reduction == 0
        mid_channels = out_channels // reduction

        self.conv1 = Conv1x1(in_channels, mid_channels)
        self.conv2 = nn.ModuleList()
        for t in range(1, T + 1):
            self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels)
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = Conv1x1Linear(in_channels, out_channels)
        self.IN = nn.InstanceNorm2d(out_channels, affine=NORM_AFFINE)

    def forward(self, x):
        identity = x
        x1 = self.conv1(x)
        x2 = 0
        for conv2_t in self.conv2:
            x2_t = conv2_t(x1)
            x2 = x2 + self.gate(x2_t)
        x3 = self.conv3(x2)
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = x3 + identity
        out = self.IN(out)  # IN outside residual
        return F.relu(out)


class OSBlockINv3(nn.Module):
    """Omni-scale feature learning block with instance normalization both inside and outside the residual."""

    def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
        super(OSBlockINv3, self).__init__()
        assert T >= 1
        assert out_channels >= reduction and out_channels % reduction == 0
        mid_channels = out_channels // reduction

        self.conv1 = Conv1x1(in_channels, mid_channels)
        self.conv2 = nn.ModuleList()
        for t in range(1, T + 1):
            self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = Conv1x1Linear(in_channels, out_channels)
        self.IN_in = nn.InstanceNorm2d(out_channels, affine=NORM_AFFINE)
        self.IN_out = nn.InstanceNorm2d(out_channels, affine=NORM_AFFINE)

    def forward(self, x):
        identity = x
        x1 = self.conv1(x)
        x2 = 0
        for conv2_t in self.conv2:
            x2_t = conv2_t(x1)
            x2 = x2 + self.gate(x2_t)
        x3 = self.conv3(x2)
        x3 = self.IN_in(x3)  # IN inside residual
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = x3 + identity
        out = self.IN_out(out)  # IN outside residual
        return F.relu(out)


class NASBlock(nn.Module):
    """Neural architecture search layer."""

    def __init__(self, in_channels, out_channels, search_space=None):
        super(NASBlock, self).__init__()
        self._is_child_graph = False
        self.search_space = search_space
        if self.search_space is None:
            raise ValueError('search_space is None')

        self.os_block = nn.ModuleList()
        for block in self.search_space:
            self.os_block += [block(in_channels, out_channels)]
        # architecture weights, one per candidate block
        self.weights = nn.Parameter(torch.ones(len(self.search_space)))

    def build_child_graph(self):
        if self._is_child_graph:
            raise RuntimeError('build_child_graph() can only be called once')

        # keep only the candidate with the largest architecture weight
        idx = self.weights.data.max(dim=0)[1].item()
        self.os_block = self.os_block[idx]
        self.weights = None
        self._is_child_graph = True
        return self.search_space[idx]

    def forward(self, x, lmda=1.):
        if self._is_child_graph:
            return self.os_block(x)

        # sample Gumbel noise: g = -log(-log(u)), u ~ Uniform(0, 1)
        uniform = torch.rand_like(self.weights)
        gumbel = -torch.log(-torch.log(uniform + EPS))
        nonneg_weights = F.relu(self.weights)
        logits = torch.log(nonneg_weights + EPS) + gumbel
        # Gumbel-softmax with temperature lmda (concrete distribution)
        exp = torch.exp(logits / lmda)
        weights_softmax = exp / (exp.sum() + EPS)

        output = 0
        for i, weight in enumerate(weights_softmax):
            output = output + weight * self.os_block[i](x)
        return output


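# A minimal sketch of the NASBlock search step (the two-candidate search space
# and tensor sizes are illustrative assumptions): the parent graph mixes all
# candidates with Gumbel-softmax weights at temperature lmda, and
# build_child_graph() then freezes the argmax candidate.
def _nas_block_sketch():
    nas = NASBlock(64, 64, search_space=[OSBlock, OSBlockINv1])
    x = torch.randn(2, 64, 16, 16)
    y_parent = nas(x, lmda=0.5)  # stochastic mixture of both candidates
    chosen = nas.build_child_graph()  # returns the selected block class
    y_child = nas(x)  # deterministic: only the chosen block runs
    assert y_parent.shape == y_child.shape
    return chosen

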
##########
# Network architecture
##########
class OSNet(nn.Module):
    """Omni-Scale Network.

    Reference:
        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
        - Zhou et al. Learning Generalisable Omni-Scale Representations
          for Person Re-Identification. arXiv preprint, 2019.
    """

    def __init__(
        self,
        num_classes,
        blocks,
        layers,
        channels,
        feature_dim=512,
        loss='softmax',
        search_space=None,
        **kwargs
    ):
        super(OSNet, self).__init__()
        num_blocks = len(blocks)
        assert num_blocks == len(layers)
        assert num_blocks == len(channels) - 1
        # no matter what loss is specified, the model only returns the ID predictions
        self.loss = loss
        self.feature_dim = feature_dim

        # convolutional backbone
        self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=True)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.conv2 = self._make_layer(
            blocks[0], layers[0], channels[0], channels[1], search_space
        )
        self.pool2 = nn.Sequential(
            Conv1x1(channels[1], channels[1]), nn.AvgPool2d(2, stride=2)
        )
        self.conv3 = self._make_layer(
            blocks[1], layers[1], channels[1], channels[2], search_space
        )
        self.pool3 = nn.Sequential(
            Conv1x1(channels[2], channels[2]), nn.AvgPool2d(2, stride=2)
        )
        self.conv4 = self._make_layer(
            blocks[2], layers[2], channels[2], channels[3], search_space
        )
        self.conv5 = Conv1x1(channels[3], channels[3])
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        # fully connected layer
        self.fc = self._construct_fc_layer(
            self.feature_dim, channels[3], dropout_p=None
        )
        # identity classification layer
        self.classifier = nn.Linear(self.feature_dim, num_classes)

    def _make_layer(
        self, block, layer, in_channels, out_channels, search_space
    ):
        layers = nn.ModuleList()
        layers += [block(in_channels, out_channels, search_space=search_space)]
        for _ in range(1, layer):
            layers += [
                block(out_channels, out_channels, search_space=search_space)
            ]
        return layers

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        # fc_dims may be None (no fc layer), an int, or a list of ints;
        # guard the < 0 check so list inputs do not raise a TypeError
        if fc_dims is None or (isinstance(fc_dims, int) and fc_dims < 0):
            self.feature_dim = input_dim
            return None

        if isinstance(fc_dims, int):
            fc_dims = [fc_dims]

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim, affine=NORM_AFFINE))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def build_child_graph(self):
        print('Building child graph')
        for i, conv in enumerate(self.conv2):
            block = conv.build_child_graph()
            print('- conv2-{} Block={}'.format(i + 1, block.__name__))
        for i, conv in enumerate(self.conv3):
            block = conv.build_child_graph()
            print('- conv3-{} Block={}'.format(i + 1, block.__name__))
        for i, conv in enumerate(self.conv4):
            block = conv.build_child_graph()
            print('- conv4-{} Block={}'.format(i + 1, block.__name__))

    def featuremaps(self, x, lmda):
        x = self.conv1(x)
        x = self.maxpool(x)
        for conv in self.conv2:
            x = conv(x, lmda)
        x = self.pool2(x)
        for conv in self.conv3:
            x = conv(x, lmda)
        x = self.pool3(x)
        for conv in self.conv4:
            x = conv(x, lmda)
        return self.conv5(x)

    def forward(self, x, lmda=1., return_featuremaps=False):
        # lmda (float): temperature parameter for the concrete distribution
        x = self.featuremaps(x, lmda)
        if return_featuremaps:
            return x
        v = self.global_avgpool(x)
        v = v.view(v.size(0), -1)
        if self.fc is not None:
            v = self.fc(v)
        if not self.training:
            return v
        return self.classifier(v)


##########
# Instantiation
##########
def osnet_nas(num_classes=1000, loss='softmax', **kwargs):
    # standard size (width x1.0)
    return OSNet(
        num_classes,
        blocks=[NASBlock, NASBlock, NASBlock],
        layers=[2, 2, 2],
        channels=[64, 256, 384, 512],
        loss=loss,
        search_space=[OSBlock, OSBlockINv1, OSBlockINv2, OSBlockINv3],
        **kwargs
    )


__NAS_models = {'osnet_nas': osnet_nas}


def build_model(name, num_classes=100):
    avai_models = list(__NAS_models.keys())
    if name not in avai_models:
        raise KeyError(
            'Unknown model: {}. Must be one of {}'.format(name, avai_models)
        )
    return __NAS_models[name](num_classes=num_classes)


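# A minimal end-to-end sketch (the input resolution, batch size and class
# count are illustrative assumptions): build the NAS parent model, run a
# forward pass, then collapse it into a child graph.
if __name__ == '__main__':
    model = build_model('osnet_nas', num_classes=10)
    dummy = torch.randn(2, 3, 256, 128)  # a common re-id input size (H=256, W=128)
    model.train()
    logits = model(dummy, lmda=1.)  # parent graph: mixes all candidate blocks
    assert logits.shape == (2, 10)
    model.build_child_graph()  # freeze the best candidate per NAS layer
    model.eval()
    with torch.no_grad():
        feats = model(dummy)  # in eval mode the model returns feature embeddings
    assert feats.shape == (2, model.feature_dim)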