mirror of https://github.com/JDAI-CV/fast-reid.git
parent
84c733fa85
commit
5d4758125d
|
@ -30,7 +30,42 @@ model_urls = {
|
|||
152: 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
}
|
||||
|
||||
__all__ = ['ResNet', 'Bottleneck']
|
||||
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck']
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, with_se=False,
|
||||
stride=1, downsample=None, reduction=16):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = get_norm(bn_norm, planes, num_splits)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = get_norm(bn_norm, planes, num_splits)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
if with_se: self.se = SELayer(planes, reduction)
|
||||
else: self.se = nn.Identity()
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
|
@ -40,20 +75,16 @@ class Bottleneck(nn.Module):
|
|||
stride=1, downsample=None, reduction=16):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
if with_ibn:
|
||||
self.bn1 = IBN(planes, bn_norm, num_splits)
|
||||
else:
|
||||
self.bn1 = get_norm(bn_norm, planes, num_splits)
|
||||
if with_ibn: self.bn1 = IBN(planes, bn_norm, num_splits)
|
||||
else: self.bn1 = get_norm(bn_norm, planes, num_splits)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
self.bn2 = get_norm(bn_norm, planes, num_splits)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
if with_se:
|
||||
self.se = SELayer(planes * 4, reduction)
|
||||
else:
|
||||
self.se = nn.Identity()
|
||||
if with_se: self.se = SELayer(planes * 4, reduction)
|
||||
else: self.se = nn.Identity()
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
|
@ -83,18 +114,17 @@ class Bottleneck(nn.Module):
|
|||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block, layers, non_layers):
|
||||
scale = 64
|
||||
self.inplanes = scale
|
||||
self.inplanes = 64
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = get_norm(bn_norm, 64, num_splits)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, scale, layers[0], 1, bn_norm, num_splits, with_ibn, with_se)
|
||||
self.layer2 = self._make_layer(block, scale * 2, layers[1], 2, bn_norm, num_splits, with_ibn, with_se)
|
||||
self.layer3 = self._make_layer(block, scale * 4, layers[2], 2, bn_norm, num_splits, with_ibn, with_se)
|
||||
self.layer4 = self._make_layer(block, scale * 8, layers[3], last_stride, bn_norm, num_splits, with_se=with_se)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn, with_se)
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn, with_se)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn, with_se)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_se=with_se)
|
||||
|
||||
self.random_init()
|
||||
|
||||
|
@ -213,9 +243,10 @@ def build_resnet_backbone(cfg):
|
|||
with_nl = cfg.MODEL.BACKBONE.WITH_NL
|
||||
depth = cfg.MODEL.BACKBONE.DEPTH
|
||||
|
||||
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
|
||||
nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 9, 0]}[depth]
|
||||
model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, Bottleneck,
|
||||
num_blocks_per_stage = {34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
|
||||
nl_layers_per_stage = {34: [3, 4, 6, 3], 50: [0, 2, 3, 0], 101: [0, 2, 9, 0]}[depth]
|
||||
block = {34: BasicBlock, 50: Bottleneck}[depth]
|
||||
model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block,
|
||||
num_blocks_per_stage, nl_layers_per_stage)
|
||||
if pretrain:
|
||||
if not with_ibn:
|
||||
|
|
|
@ -37,7 +37,7 @@ class BNneckHead(nn.Module):
|
|||
"""
|
||||
global_feat = self.pool_layer(features)
|
||||
bn_feat = self.bnneck(global_feat)
|
||||
bn_feat = Flatten()(bn_feat)
|
||||
bn_feat = bn_feat[..., 0, 0]
|
||||
# Evaluation
|
||||
if not self.training:
|
||||
return bn_feat
|
||||
|
@ -48,7 +48,7 @@ class BNneckHead(nn.Module):
|
|||
pred_class_logits = self.classifier(bn_feat, targets)
|
||||
|
||||
if self.neck_feat == "before":
|
||||
feat = Flatten()(global_feat)
|
||||
feat = global_feat[..., 0, 0]
|
||||
elif self.neck_feat == "after":
|
||||
feat = bn_feat
|
||||
else:
|
||||
|
|
|
@ -29,7 +29,7 @@ class LinearHead(nn.Module):
|
|||
See :class:`ReIDHeads.forward`.
|
||||
"""
|
||||
global_feat = self.pool_layer(features)
|
||||
global_feat = Flatten()(global_feat)
|
||||
global_feat = global_feat[..., 0, 0]
|
||||
if not self.training:
|
||||
return global_feat
|
||||
# training
|
||||
|
|
|
@ -45,7 +45,7 @@ class ReductionHead(nn.Module):
|
|||
global_feat = self.pool_layer(features)
|
||||
global_feat = self.bottleneck(global_feat)
|
||||
bn_feat = self.bnneck(global_feat)
|
||||
bn_feat = Flatten()(bn_feat)
|
||||
bn_feat = bn_feat[..., 0, 0]
|
||||
# Evaluation
|
||||
if not self.training:
|
||||
return bn_feat
|
||||
|
@ -56,7 +56,7 @@ class ReductionHead(nn.Module):
|
|||
pred_class_logits = self.classifier(bn_feat, targets)
|
||||
|
||||
if self.neck_feat == "before":
|
||||
feat = Flatten()(global_feat)
|
||||
feat = global_feat[..., 0, 0]
|
||||
elif self.neck_feat == "after":
|
||||
feat = bn_feat
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue