support ResNet34 backbone

Summary: add BasicBlock to support ResNet34
pull/63/head
liaoxingyu 2020-05-26 13:18:09 +08:00
parent 84c733fa85
commit 5d4758125d
4 changed files with 54 additions and 23 deletions
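For context on why ResNet34 needs its own block type: ResNet-18/34 are built from two-layer BasicBlocks with expansion = 1, while ResNet-50/101/152 use three-layer Bottlenecks with expansion = 4, so a ResNet-34 backbone ends at 512 channels instead of 2048. The sketch below is a simplified, self-contained illustration of the block this commit adds; it is not the committed code (plain nn.BatchNorm2d stands in for this repo's get_norm, and the SE branch is omitted).

import torch
from torch import nn

class BasicBlockSketch(nn.Module):
    """Simplified two-conv residual block mirroring the BasicBlock added in this commit
    (nn.BatchNorm2d stands in for get_norm; the optional SE layer is omitted)."""
    expansion = 1  # output channels == planes, unlike Bottleneck's planes * 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)

# Channel count stays at `planes` (expansion = 1).
block = BasicBlockSketch(64, 64)
print(block(torch.randn(2, 64, 56, 56)).shape)  # torch.Size([2, 64, 56, 56])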

View File

@@ -30,7 +30,42 @@ model_urls = {
     152: 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
 }
 
-__all__ = ['ResNet', 'Bottleneck']
+__all__ = ['ResNet', 'BasicBlock', 'Bottleneck']
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, with_se=False,
+                 stride=1, downsample=None, reduction=16):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = get_norm(bn_norm, planes, num_splits)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = get_norm(bn_norm, planes, num_splits)
+        self.relu = nn.ReLU(inplace=True)
+        if with_se: self.se = SELayer(planes, reduction)
+        else:       self.se = nn.Identity()
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
 class Bottleneck(nn.Module):
@@ -40,20 +75,16 @@ class Bottleneck(nn.Module):
                  stride=1, downsample=None, reduction=16):
         super(Bottleneck, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
-        if with_ibn:
-            self.bn1 = IBN(planes, bn_norm, num_splits)
-        else:
-            self.bn1 = get_norm(bn_norm, planes, num_splits)
+        if with_ibn: self.bn1 = IBN(planes, bn_norm, num_splits)
+        else:        self.bn1 = get_norm(bn_norm, planes, num_splits)
         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                padding=1, bias=False)
         self.bn2 = get_norm(bn_norm, planes, num_splits)
         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
         self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
         self.relu = nn.ReLU(inplace=True)
-        if with_se:
-            self.se = SELayer(planes * 4, reduction)
-        else:
-            self.se = nn.Identity()
+        if with_se: self.se = SELayer(planes * 4, reduction)
+        else:       self.se = nn.Identity()
         self.downsample = downsample
         self.stride = stride
@@ -83,18 +114,17 @@ class Bottleneck(nn.Module):
 class ResNet(nn.Module):
     def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block, layers, non_layers):
-        scale = 64
-        self.inplanes = scale
+        self.inplanes = 64
         super().__init__()
         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                bias=False)
         self.bn1 = get_norm(bn_norm, 64, num_splits)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, scale, layers[0], 1, bn_norm, num_splits, with_ibn, with_se)
-        self.layer2 = self._make_layer(block, scale * 2, layers[1], 2, bn_norm, num_splits, with_ibn, with_se)
-        self.layer3 = self._make_layer(block, scale * 4, layers[2], 2, bn_norm, num_splits, with_ibn, with_se)
-        self.layer4 = self._make_layer(block, scale * 8, layers[3], last_stride, bn_norm, num_splits, with_se=with_se)
+        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn, with_se)
+        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn, with_se)
+        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn, with_se)
+        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_se=with_se)
         self.random_init()
@@ -213,9 +243,10 @@ def build_resnet_backbone(cfg):
     with_nl = cfg.MODEL.BACKBONE.WITH_NL
     depth = cfg.MODEL.BACKBONE.DEPTH
 
-    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
-    nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 9, 0]}[depth]
-    model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, Bottleneck,
+    num_blocks_per_stage = {34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
+    nl_layers_per_stage = {34: [3, 4, 6, 3], 50: [0, 2, 3, 0], 101: [0, 2, 9, 0]}[depth]
+    block = {34: BasicBlock, 50: Bottleneck}[depth]
+    model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block,
                    num_blocks_per_stage, nl_layers_per_stage)
     if pretrain:
         if not with_ibn:
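With this change, cfg.MODEL.BACKBONE.DEPTH selects both the stage layout and the block type. Below is a minimal standalone mirror of the new lookup (the helper name is hypothetical; the real function reads depth from the config and builds a ResNet). Note that, as committed, the block lookup lists only depths 34 and 50, so 101 or 152 would raise a KeyError there unless entries are added.

# Illustrative, standalone mirror of the dispatch added to build_resnet_backbone.
def pick_resnet_config(depth):
    num_blocks_per_stage = {34: [3, 4, 6, 3], 50: [3, 4, 6, 3],
                            101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]
    block = {34: 'BasicBlock', 50: 'Bottleneck'}[depth]  # as committed: only 34 and 50 are listed
    return block, num_blocks_per_stage

print(pick_resnet_config(34))  # ('BasicBlock', [3, 4, 6, 3])
print(pick_resnet_config(50))  # ('Bottleneck', [3, 4, 6, 3])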

View File

@@ -37,7 +37,7 @@ class BNneckHead(nn.Module):
         """
         global_feat = self.pool_layer(features)
         bn_feat = self.bnneck(global_feat)
-        bn_feat = Flatten()(bn_feat)
+        bn_feat = bn_feat[..., 0, 0]
 
         # Evaluation
         if not self.training:
             return bn_feat
@@ -48,7 +48,7 @@ class BNneckHead(nn.Module):
         pred_class_logits = self.classifier(bn_feat, targets)
 
         if self.neck_feat == "before":
-            feat = Flatten()(global_feat)
+            feat = global_feat[..., 0, 0]
         elif self.neck_feat == "after":
             feat = bn_feat
         else:

View File

@@ -29,7 +29,7 @@ class LinearHead(nn.Module):
         See :class:`ReIDHeads.forward`.
         """
         global_feat = self.pool_layer(features)
-        global_feat = Flatten()(global_feat)
+        global_feat = global_feat[..., 0, 0]
 
         if not self.training:
             return global_feat
 
         # training

View File

@@ -45,7 +45,7 @@ class ReductionHead(nn.Module):
         global_feat = self.pool_layer(features)
         global_feat = self.bottleneck(global_feat)
         bn_feat = self.bnneck(global_feat)
-        bn_feat = Flatten()(bn_feat)
+        bn_feat = bn_feat[..., 0, 0]
 
         # Evaluation
         if not self.training:
             return bn_feat
@@ -56,7 +56,7 @@ class ReductionHead(nn.Module):
         pred_class_logits = self.classifier(bn_feat, targets)
 
         if self.neck_feat == "before":
-            feat = Flatten()(global_feat)
+            feat = global_feat[..., 0, 0]
         elif self.neck_feat == "after":
             feat = bn_feat
         else:
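The three head files all make the same swap: Flatten()(x) becomes x[..., 0, 0]. For the (N, C, 1, 1) tensors produced by global pooling the two forms give the same (N, C) matrix, and the indexing form avoids constructing a module per call. A quick sanity check (a sketch; torch.nn.Flatten stands in for this repo's Flatten layer, and the equivalence only holds when the spatial dims really are 1x1):

import torch
from torch import nn

# On an (N, C, 1, 1) globally pooled tensor, indexing the trailing spatial dims
# is equivalent to flattening from dim 1.
feat = torch.randn(8, 2048, 7, 7)
pooled = nn.AdaptiveAvgPool2d(1)(feat)   # shape (8, 2048, 1, 1)
a = pooled[..., 0, 0]                    # new form used in the heads
b = nn.Flatten()(pooled)                 # flatten form it replaces
print(a.shape, torch.equal(a, b))        # torch.Size([8, 2048]) True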