Mirror of https://github.com/JDAI-CV/fast-reid.git

remove `num_splits` in batchnorm

Summary: `num_splits` works for GhostBN, but it's very uncommon.

parent 3c48bf78c1
commit 1b84348619

fastreid/modeling/backbones
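For context on the summary line: Ghost Batch Normalization (GhostBN) is the one norm here that actually consumes `num_splits`. It computes batch statistics over fixed-size virtual sub-batches instead of the whole batch, which matters mostly for very large batch training. A minimal sketch of the idea, not fastreid's implementation:

```python
import torch

def ghost_batch_norm(x, num_splits, eps=1e-5):
    # Normalize each virtual sub-batch with its own statistics, then re-join.
    chunks = x.chunk(num_splits, dim=0)
    out = [(c - c.mean(dim=(0, 2, 3), keepdim=True))
           / (c.var(dim=(0, 2, 3), unbiased=False, keepdim=True) + eps).sqrt()
           for c in chunks]
    return torch.cat(out, dim=0)

x = torch.randn(8, 3, 4, 4)
y = ghost_batch_norm(x, num_splits=4)  # statistics over 2-sample sub-batches
```

Since every other norm ignores the argument, the commit drops it from all shared signatures and routes it through `**kwargs` instead.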
```diff
@@ -22,7 +22,7 @@ __all__ = [
 class BatchNorm(nn.BatchNorm2d):
     def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
-                 bias_init=0.0):
+                 bias_init=0.0, **kwargs):
         super().__init__(num_features, eps=eps, momentum=momentum)
         if weight_init is not None: nn.init.constant_(self.weight, weight_init)
         if bias_init is not None: nn.init.constant_(self.bias, bias_init)
@@ -34,20 +34,20 @@ class SyncBatchNorm(nn.SyncBatchNorm):
     def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                  bias_init=0.0):
         super().__init__(num_features, eps=eps, momentum=momentum)
-        if weight_init is not None: self.weight.data.fill_(weight_init)
-        if bias_init is not None: self.bias.data.fill_(bias_init)
+        if weight_init is not None: nn.init.constant_(self.weight, weight_init)
+        if bias_init is not None: nn.init.constant_(self.bias, bias_init)
         self.weight.requires_grad_(not weight_freeze)
         self.bias.requires_grad_(not bias_freeze)
 
 
 class IBN(nn.Module):
-    def __init__(self, planes, bn_norm, num_splits):
+    def __init__(self, planes, bn_norm, **kwargs):
         super(IBN, self).__init__()
         half1 = int(planes / 2)
         self.half = half1
         half2 = planes - half1
         self.IN = nn.InstanceNorm2d(half1, affine=True)
-        self.BN = get_norm(bn_norm, half2, num_splits)
+        self.BN = get_norm(bn_norm, half2, **kwargs)
 
     def forward(self, x):
         split = torch.split(x, self.half, 1)
@@ -100,8 +100,8 @@ class FrozenBatchNorm(BatchNorm):
     _version = 3
 
-    def __init__(self, num_features, eps=1e-5):
-        super().__init__(num_features, weight_freeze=True, bias_freeze=True)
+    def __init__(self, num_features, eps=1e-5, **kwargs):
+        super().__init__(num_features, weight_freeze=True, bias_freeze=True, **kwargs)
         self.num_features = num_features
         self.eps = eps
 
 
@@ -184,10 +184,14 @@ class FrozenBatchNorm(BatchNorm):
         return res
 
 
-def get_norm(norm, out_channels, num_splits=1, **kwargs):
+def get_norm(norm, out_channels, **kwargs):
     """
     Args:
-        norm (str or callable):
+        norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN;
+            or a callable that takes a channel number and returns
+            the normalization layer as a nn.Module
         out_channels: number of channels for normalization layer
 
     Returns:
         nn.Module or None: the normalization layer
     """
@@ -195,10 +199,10 @@ def get_norm(norm, out_channels, num_splits=1, **kwargs):
     if len(norm) == 0:
         return None
     norm = {
-        "BN": BatchNorm(out_channels, **kwargs),
-        "GhostBN": GhostBatchNorm(out_channels, num_splits, **kwargs),
-        "FrozenBN": FrozenBatchNorm(out_channels),
-        "GN": nn.GroupNorm(32, out_channels),
-        "syncBN": SyncBatchNorm(out_channels, **kwargs),
+        "BN": BatchNorm,
+        "GhostBN": GhostBatchNorm,
+        "FrozenBN": FrozenBatchNorm,
+        "GN": lambda channels, **args: nn.GroupNorm(32, channels),
+        "syncBN": SyncBatchNorm,
     }[norm]
-    return norm
+    return norm(out_channels, **kwargs)
```
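The core of the change is in `get_norm` above: the dict now maps names to classes (or a lambda) rather than to eagerly constructed instances, so a single call no longer builds five modules just to pick one, and only `GhostBatchNorm` ever consumes `num_splits`, which now arrives via `**kwargs`. A condensed sketch of the pattern (the `GhostBatchNorm` body here is a stand-in, not fastreid's):

```python
import torch.nn as nn

class BatchNorm(nn.BatchNorm2d):
    # **kwargs lets call sites pass GhostBN-only options (e.g. num_splits)
    # without breaking plain BN.
    def __init__(self, num_features, eps=1e-05, momentum=0.1, **kwargs):
        super().__init__(num_features, eps=eps, momentum=momentum)

class GhostBatchNorm(BatchNorm):
    # Only this subclass actually consumes num_splits.
    def __init__(self, num_features, num_splits=1, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits

def get_norm(norm, out_channels, **kwargs):
    norm = {
        "BN": BatchNorm,
        "GhostBN": GhostBatchNorm,
        "GN": lambda channels, **args: nn.GroupNorm(32, channels),
    }[norm]
    return norm(out_channels, **kwargs)

bn = get_norm("BN", 64)                      # plain BN
gbn = get_norm("GhostBN", 64, num_splits=4)  # only GhostBN sees num_splits
```

The remaining files in the commit update call sites to match the shorter signatures.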
```diff
@@ -7,7 +7,7 @@ from .batch_norm import get_norm
 
 
 class Non_local(nn.Module):
-    def __init__(self, in_channels, bn_norm, num_splits, reduc_ratio=2):
+    def __init__(self, in_channels, bn_norm, reduc_ratio=2):
         super(Non_local, self).__init__()
 
         self.in_channels = in_channels
@@ -19,7 +19,7 @@ class Non_local(nn.Module):
         self.W = nn.Sequential(
             nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                       kernel_size=1, stride=1, padding=0),
-            get_norm(bn_norm, self.in_channels, num_splits),
+            get_norm(bn_norm, self.in_channels),
         )
         nn.init.constant_(self.W[1].weight, 0.0)
         nn.init.constant_(self.W[1].bias, 0.0)
@@ -31,10 +31,10 @@ class Non_local(nn.Module):
                              kernel_size=1, stride=1, padding=0)
 
     def forward(self, x):
-        '''
+        """
         :param x: (b, t, h, w)
         :return x: (b, t, h, w)
-        '''
+        """
         batch_size = x.size(0)
         g_x = self.g(x).view(batch_size, self.inter_channels, -1)
         g_x = g_x.permute(0, 2, 1)
```
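One detail worth noting in the `Non_local` hunks: the norm layer inside `self.W` is zero-initialized, so the residual branch contributes nothing at the start of training and the block begins as an identity mapping. A small check of that property, with a plain `BatchNorm2d` standing in for `get_norm`:

```python
import torch
import torch.nn as nn

# W projects the attention output back to in_channels; its norm starts at zero.
W = nn.Sequential(
    nn.Conv2d(8, 16, kernel_size=1, stride=1, padding=0),
    nn.BatchNorm2d(16),
)
nn.init.constant_(W[1].weight, 0.0)
nn.init.constant_(W[1].bias, 0.0)

y = torch.randn(2, 8, 4, 4)  # stand-in for the attention output
print(W(y).abs().max())      # tensor(0.) -> z = W(y) + x equals x at init
```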
```diff
@@ -43,7 +43,6 @@ class ConvLayer(nn.Module):
             out_channels,
             kernel_size,
             bn_norm,
-            num_splits,
             stride=1,
             padding=0,
             groups=1,
@@ -62,7 +61,7 @@ class ConvLayer(nn.Module):
         if IN:
             self.bn = nn.InstanceNorm2d(out_channels, affine=True)
         else:
-            self.bn = get_norm(bn_norm, out_channels, num_splits)
+            self.bn = get_norm(bn_norm, out_channels)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -75,7 +74,7 @@ class ConvLayer(nn.Module):
 class Conv1x1(nn.Module):
     """1x1 convolution + bn + relu."""
 
-    def __init__(self, in_channels, out_channels, bn_norm, num_splits, stride=1, groups=1):
+    def __init__(self, in_channels, out_channels, bn_norm, stride=1, groups=1):
         super(Conv1x1, self).__init__()
         self.conv = nn.Conv2d(
             in_channels,
@@ -86,7 +85,7 @@ class Conv1x1(nn.Module):
             bias=False,
             groups=groups
         )
-        self.bn = get_norm(bn_norm, out_channels, num_splits)
+        self.bn = get_norm(bn_norm, out_channels)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -99,12 +98,12 @@ class Conv1x1(nn.Module):
 class Conv1x1Linear(nn.Module):
     """1x1 convolution + bn (w/o non-linearity)."""
 
-    def __init__(self, in_channels, out_channels, bn_norm, num_splits, stride=1):
+    def __init__(self, in_channels, out_channels, bn_norm, stride=1):
         super(Conv1x1Linear, self).__init__()
         self.conv = nn.Conv2d(
             in_channels, out_channels, 1, stride=stride, padding=0, bias=False
         )
-        self.bn = get_norm(bn_norm, out_channels, num_splits)
+        self.bn = get_norm(bn_norm, out_channels)
 
     def forward(self, x):
         x = self.conv(x)
@@ -115,7 +114,7 @@ class Conv1x1Linear(nn.Module):
 class Conv3x3(nn.Module):
     """3x3 convolution + bn + relu."""
 
-    def __init__(self, in_channels, out_channels, bn_norm, num_splits, stride=1, groups=1):
+    def __init__(self, in_channels, out_channels, bn_norm, stride=1, groups=1):
         super(Conv3x3, self).__init__()
         self.conv = nn.Conv2d(
             in_channels,
@@ -126,7 +125,7 @@ class Conv3x3(nn.Module):
             bias=False,
             groups=groups
         )
-        self.bn = get_norm(bn_norm, out_channels, num_splits)
+        self.bn = get_norm(bn_norm, out_channels)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -141,7 +140,7 @@ class LightConv3x3(nn.Module):
     1x1 (linear) + dw 3x3 (nonlinear).
     """
 
-    def __init__(self, in_channels, out_channels, bn_norm, num_splits):
+    def __init__(self, in_channels, out_channels, bn_norm):
         super(LightConv3x3, self).__init__()
         self.conv1 = nn.Conv2d(
             in_channels, out_channels, 1, stride=1, padding=0, bias=False
@@ -155,7 +154,7 @@ class LightConv3x3(nn.Module):
             bias=False,
             groups=out_channels
         )
-        self.bn = get_norm(bn_norm, out_channels, num_splits)
+        self.bn = get_norm(bn_norm, out_channels)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -235,35 +234,34 @@ class OSBlock(nn.Module):
             in_channels,
             out_channels,
             bn_norm,
-            num_splits,
             IN=False,
             bottleneck_reduction=4,
             **kwargs
     ):
         super(OSBlock, self).__init__()
         mid_channels = out_channels // bottleneck_reduction
-        self.conv1 = Conv1x1(in_channels, mid_channels, bn_norm, num_splits)
-        self.conv2a = LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits)
+        self.conv1 = Conv1x1(in_channels, mid_channels, bn_norm)
+        self.conv2a = LightConv3x3(mid_channels, mid_channels, bn_norm)
         self.conv2b = nn.Sequential(
-            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
-            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm),
+            LightConv3x3(mid_channels, mid_channels, bn_norm),
         )
         self.conv2c = nn.Sequential(
-            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
-            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
-            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm),
+            LightConv3x3(mid_channels, mid_channels, bn_norm),
+            LightConv3x3(mid_channels, mid_channels, bn_norm),
         )
         self.conv2d = nn.Sequential(
-            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
-            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
-            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
-            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm),
+            LightConv3x3(mid_channels, mid_channels, bn_norm),
+            LightConv3x3(mid_channels, mid_channels, bn_norm),
+            LightConv3x3(mid_channels, mid_channels, bn_norm),
         )
         self.gate = ChannelGate(mid_channels)
-        self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn_norm, num_splits)
+        self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn_norm)
         self.downsample = None
         if in_channels != out_channels:
-            self.downsample = Conv1x1Linear(in_channels, out_channels, bn_norm, num_splits)
+            self.downsample = Conv1x1Linear(in_channels, out_channels, bn_norm)
         self.IN = None
         if IN: self.IN = nn.InstanceNorm2d(out_channels, affine=True)
         self.relu = nn.ReLU(True)
@@ -303,7 +301,6 @@ class OSNet(nn.Module):
             layers,
             channels,
             bn_norm,
-            num_splits,
             IN=False,
             **kwargs
     ):
@@ -313,7 +310,7 @@ class OSNet(nn.Module):
         assert num_blocks == len(channels) - 1
 
         # convolutional backbone
-        self.conv1 = ConvLayer(3, channels[0], 7, bn_norm, num_splits, stride=2, padding=3, IN=IN)
+        self.conv1 = ConvLayer(3, channels[0], 7, bn_norm, stride=2, padding=3, IN=IN)
         self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
         self.conv2 = self._make_layer(
             blocks[0],
@@ -321,7 +318,6 @@ class OSNet(nn.Module):
             channels[0],
             channels[1],
             bn_norm,
-            num_splits,
             reduce_spatial_size=True,
             IN=IN
         )
@@ -331,7 +327,6 @@ class OSNet(nn.Module):
             channels[1],
             channels[2],
             bn_norm,
-            num_splits,
             reduce_spatial_size=True
         )
         self.conv4 = self._make_layer(
@@ -340,10 +335,9 @@ class OSNet(nn.Module):
             channels[2],
             channels[3],
             bn_norm,
-            num_splits,
             reduce_spatial_size=False
         )
-        self.conv5 = Conv1x1(channels[3], channels[3], bn_norm, num_splits)
+        self.conv5 = Conv1x1(channels[3], channels[3], bn_norm)
 
         self._init_params()
 
@@ -354,20 +348,19 @@ class OSNet(nn.Module):
             in_channels,
             out_channels,
             bn_norm,
-            num_splits,
             reduce_spatial_size,
             IN=False
     ):
         layers = []
 
-        layers.append(block(in_channels, out_channels, bn_norm, num_splits, IN=IN))
+        layers.append(block(in_channels, out_channels, bn_norm, IN=IN))
         for i in range(1, layer):
-            layers.append(block(out_channels, out_channels, bn_norm, num_splits, IN=IN))
+            layers.append(block(out_channels, out_channels, bn_norm, IN=IN))
 
         if reduce_spatial_size:
             layers.append(
                 nn.Sequential(
-                    Conv1x1(out_channels, out_channels, bn_norm, num_splits),
+                    Conv1x1(out_channels, out_channels, bn_norm),
                     nn.AvgPool2d(2, stride=2),
                 )
             )
@@ -498,12 +491,11 @@ def build_osnet_backbone(cfg):
     """
 
     # fmt: off
-    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
+    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
     pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
-    with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
-    bn_norm = cfg.MODEL.BACKBONE.NORM
-    num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
-    depth = cfg.MODEL.BACKBONE.DEPTH
+    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
+    bn_norm       = cfg.MODEL.BACKBONE.NORM
+    depth         = cfg.MODEL.BACKBONE.DEPTH
     # fmt: on
 
     num_blocks_per_stage = [2, 2, 2]
@@ -513,7 +505,7 @@ def build_osnet_backbone(cfg):
                               "x0_5": [32, 128, 192, 256],
                               "x0_25": [16, 64, 96, 128]}[depth]
     model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage,
-                  bn_norm, num_splits, IN=with_ibn)
+                  bn_norm, IN=with_ibn)
 
     if pretrain:
         # Load pretrain path if specifically
```
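`LightConv3x3`, which `OSBlock` stacks in streams of one to four units above, is the factorized unit its docstring names: a linear 1x1 pointwise conv followed by a nonlinear depthwise 3x3. A self-contained sketch, with `nn.BatchNorm2d` standing in for `get_norm(bn_norm, ...)`:

```python
import torch
import torch.nn as nn

class LightConv3x3(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Pointwise 1x1, deliberately with no nonlinearity after it.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
        # groups=out_channels makes this depthwise: one 3x3 filter per channel.
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1,
                               bias=False, groups=out_channels)
        self.bn = nn.BatchNorm2d(out_channels)  # get_norm(bn_norm, out_channels) in the diff
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv2(self.conv1(x))))

y = LightConv3x3(16, 32)(torch.randn(4, 16, 8, 8))  # -> (4, 32, 8, 8)
```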
```diff
@@ -46,7 +46,7 @@ class Bottleneck(nn.Module):
     # pylint: disable=unused-argument
     expansion = 4
 
-    def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, stride=1, downsample=None,
+    def __init__(self, inplanes, planes, bn_norm, with_ibn=False, stride=1, downsample=None,
                  radix=1, cardinality=1, bottleneck_width=64,
                  avd=False, avd_first=False, dilation=1, is_first=False,
                  rectified_conv=False, rectify_avg=False,
@@ -55,9 +55,9 @@ class Bottleneck(nn.Module):
         group_width = int(planes * (bottleneck_width / 64.)) * cardinality
         self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
         if with_ibn:
-            self.bn1 = IBN(group_width, bn_norm, num_splits)
+            self.bn1 = IBN(group_width, bn_norm)
         else:
-            self.bn1 = get_norm(bn_norm, group_width, num_splits)
+            self.bn1 = get_norm(bn_norm, group_width)
         self.dropblock_prob = dropblock_prob
         self.radix = radix
         self.avd = avd and (stride > 1 or is_first)
@@ -74,7 +74,7 @@ class Bottleneck(nn.Module):
                 dilation=dilation, groups=cardinality, bias=False,
                 radix=radix, rectify=rectified_conv,
                 rectify_avg=rectify_avg,
-                norm_layer=bn_norm, num_splits=num_splits,
+                norm_layer=bn_norm,
                 dropblock_prob=dropblock_prob)
         elif rectified_conv:
             from rfconv import RFConv2d
@@ -83,17 +83,17 @@ class Bottleneck(nn.Module):
                 padding=dilation, dilation=dilation,
                 groups=cardinality, bias=False,
                 average_mode=rectify_avg)
-            self.bn2 = get_norm(bn_norm, group_width, num_splits)
+            self.bn2 = get_norm(bn_norm, group_width)
         else:
             self.conv2 = nn.Conv2d(
                 group_width, group_width, kernel_size=3, stride=stride,
                 padding=dilation, dilation=dilation,
                 groups=cardinality, bias=False)
-            self.bn2 = get_norm(bn_norm, group_width, num_splits)
+            self.bn2 = get_norm(bn_norm, group_width)
 
         self.conv3 = nn.Conv2d(
             group_width, planes * 4, kernel_size=1, bias=False)
-        self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
+        self.bn3 = get_norm(bn_norm, planes * 4)
 
         if last_gamma:
             from torch.nn.init import zeros_
@@ -161,7 +161,7 @@ class ResNest(nn.Module):
     """
 
     # pylint: disable=unused-variable
-    def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_nl, block, layers, non_layers, radix=1,
+    def __init__(self, last_stride, bn_norm, with_ibn, with_nl, block, layers, non_layers, radix=1,
                  groups=1,
                  bottleneck_width=64,
                  dilated=False, dilation=1,
@@ -193,35 +193,35 @@ class ResNest(nn.Module):
         if deep_stem:
             self.conv1 = nn.Sequential(
                 conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
-                get_norm(bn_norm, stem_width, num_splits),
+                get_norm(bn_norm, stem_width),
                 nn.ReLU(inplace=True),
                 conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
-                get_norm(bn_norm, stem_width, num_splits),
+                get_norm(bn_norm, stem_width),
                 nn.ReLU(inplace=True),
                 conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
             )
         else:
             self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
                                     bias=False, **conv_kwargs)
-        self.bn1 = get_norm(bn_norm, self.inplanes, num_splits)
+        self.bn1 = get_norm(bn_norm, self.inplanes)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn=with_ibn, is_first=False)
-        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn=with_ibn)
+        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn=with_ibn, is_first=False)
+        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn=with_ibn)
         if dilated or dilation == 4:
-            self.layer3 = self._make_layer(block, 256, layers[2], 1, bn_norm, num_splits, with_ibn=with_ibn,
+            self.layer3 = self._make_layer(block, 256, layers[2], 1, bn_norm, with_ibn=with_ibn,
                                            dilation=2, dropblock_prob=dropblock_prob)
-            self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, num_splits, with_ibn=with_ibn,
+            self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, with_ibn=with_ibn,
                                            dilation=4, dropblock_prob=dropblock_prob)
         elif dilation == 2:
-            self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn,
+            self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn=with_ibn,
                                            dilation=1, dropblock_prob=dropblock_prob)
-            self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, num_splits, with_ibn=with_ibn,
+            self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, with_ibn=with_ibn,
                                            dilation=2, dropblock_prob=dropblock_prob)
         else:
-            self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn,
+            self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn=with_ibn,
                                            dropblock_prob=dropblock_prob)
-            self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_ibn=with_ibn,
+            self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_ibn=with_ibn,
                                            dropblock_prob=dropblock_prob)
 
         for m in self.modules():
@@ -232,12 +232,12 @@ class ResNest(nn.Module):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
 
-        if with_nl:
-            self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
-        else:
-            self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
+        # fmt: off
+        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm)
+        else: self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
+        # fmt: on
 
-    def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", num_splits=1, with_ibn=False,
+    def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", with_ibn=False,
                     dilation=1, dropblock_prob=0.0, is_first=True):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
@@ -254,12 +254,12 @@ class ResNest(nn.Module):
             else:
                 down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
                                              kernel_size=1, stride=stride, bias=False))
-            down_layers.append(get_norm(bn_norm, planes * block.expansion, num_splits))
+            down_layers.append(get_norm(bn_norm, planes * block.expansion))
             downsample = nn.Sequential(*down_layers)
 
         layers = []
         if dilation == 1 or dilation == 2:
-            layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, stride, downsample=downsample,
+            layers.append(block(self.inplanes, planes, bn_norm, with_ibn, stride, downsample=downsample,
                                 radix=self.radix, cardinality=self.cardinality,
                                 bottleneck_width=self.bottleneck_width,
                                 avd=self.avd, avd_first=self.avd_first,
@@ -268,7 +268,7 @@ class ResNest(nn.Module):
                                 dropblock_prob=dropblock_prob,
                                 last_gamma=self.last_gamma))
         elif dilation == 4:
-            layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, stride, downsample=downsample,
+            layers.append(block(self.inplanes, planes, bn_norm, with_ibn, stride, downsample=downsample,
                                 radix=self.radix, cardinality=self.cardinality,
                                 bottleneck_width=self.bottleneck_width,
                                 avd=self.avd, avd_first=self.avd_first,
@@ -281,7 +281,7 @@ class ResNest(nn.Module):
 
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn,
+            layers.append(block(self.inplanes, planes, bn_norm, with_ibn,
                                 radix=self.radix, cardinality=self.cardinality,
                                 bottleneck_width=self.bottleneck_width,
                                 avd=self.avd, avd_first=self.avd_first,
@@ -292,18 +292,18 @@ class ResNest(nn.Module):
         return nn.Sequential(*layers)
 
-    def _build_nonlocal(self, layers, non_layers, bn_norm, num_splits):
+    def _build_nonlocal(self, layers, non_layers, bn_norm):
         self.NL_1 = nn.ModuleList(
-            [Non_local(256, bn_norm, num_splits) for _ in range(non_layers[0])])
+            [Non_local(256, bn_norm) for _ in range(non_layers[0])])
         self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
         self.NL_2 = nn.ModuleList(
-            [Non_local(512, bn_norm, num_splits) for _ in range(non_layers[1])])
+            [Non_local(512, bn_norm) for _ in range(non_layers[1])])
         self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
         self.NL_3 = nn.ModuleList(
-            [Non_local(1024, bn_norm, num_splits) for _ in range(non_layers[2])])
+            [Non_local(1024, bn_norm) for _ in range(non_layers[2])])
         self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
         self.NL_4 = nn.ModuleList(
-            [Non_local(2048, bn_norm, num_splits) for _ in range(non_layers[3])])
+            [Non_local(2048, bn_norm) for _ in range(non_layers[3])])
         self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
 
     def forward(self, x):
@@ -364,21 +364,38 @@ def build_resnest_backbone(cfg):
     """
 
     # fmt: off
-    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
+    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
     pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
-    last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
-    bn_norm = cfg.MODEL.BACKBONE.NORM
-    num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
-    with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
-    with_se = cfg.MODEL.BACKBONE.WITH_SE
-    with_nl = cfg.MODEL.BACKBONE.WITH_NL
-    depth = cfg.MODEL.BACKBONE.DEPTH
+    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
+    bn_norm       = cfg.MODEL.BACKBONE.NORM
+    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
+    with_se       = cfg.MODEL.BACKBONE.WITH_SE
+    with_nl       = cfg.MODEL.BACKBONE.WITH_NL
+    depth         = cfg.MODEL.BACKBONE.DEPTH
     # fmt: on
 
-    num_blocks_per_stage = {"50x": [3, 4, 6, 3], "101x": [3, 4, 23, 3], "200x": [3, 24, 36, 3],
-                            "269x": [3, 30, 48, 8]}[depth]
-    nl_layers_per_stage = {"50x": [0, 2, 3, 0], "101x": [0, 2, 3, 0], "200x": [0, 2, 3, 0], "269x": [0, 2, 3, 0]}[depth]
-    stem_width = {"50x": 32, "101x": 64, "200x": 64, "269x": 64}[depth]
-    model = ResNest(last_stride, bn_norm, num_splits, with_ibn, with_nl, Bottleneck, num_blocks_per_stage,
+    num_blocks_per_stage = {
+        "50x": [3, 4, 6, 3],
+        "101x": [3, 4, 23, 3],
+        "200x": [3, 24, 36, 3],
+        "269x": [3, 30, 48, 8],
+    }[depth]
+
+    nl_layers_per_stage = {
+        "50x": [0, 2, 3, 0],
+        "101x": [0, 2, 3, 0],
+        "200x": [0, 2, 3, 0],
+        "269x": [0, 2, 3, 0],
+    }[depth]
+
+    stem_width = {
+        "50x": 32,
+        "101x": 64,
+        "200x": 64,
+        "269x": 64,
+    }[depth]
+
+    model = ResNest(last_stride, bn_norm, with_ibn, with_nl, Bottleneck, num_blocks_per_stage,
                     nl_layers_per_stage, radix=2, groups=1, bottleneck_width=64,
                     deep_stem=True, stem_width=stem_width, avg_down=True,
                     avd=True, avd_first=False)
```
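All of these `_make_layer` methods share the shortcut rule visible in the ResNest hunks above: when the stride or the channel count changes, the identity shortcut is replaced by a 1x1 projection plus a norm layer. A generic sketch of that branch (norm defaulting to `nn.BatchNorm2d` instead of `get_norm`):

```python
import torch.nn as nn

def make_downsample(inplanes, planes, expansion, stride, norm=nn.BatchNorm2d):
    # Identity shortcuts only work when shapes match; otherwise project.
    if stride != 1 or inplanes != planes * expansion:
        return nn.Sequential(
            nn.Conv2d(inplanes, planes * expansion, kernel_size=1,
                      stride=stride, bias=False),
            norm(planes * expansion),
        )
    return None

print(make_downsample(64, 64, 4, 1))   # 64 != 256 -> 1x1 projection
print(make_downsample(256, 64, 4, 1))  # shapes already match -> None
```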
```diff
@@ -38,16 +38,16 @@ model_urls = {
 class BasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, with_se=False,
+    def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False,
                  stride=1, downsample=None, reduction=16):
         super(BasicBlock, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
         if with_ibn:
-            self.bn1 = IBN(planes, bn_norm, num_splits)
+            self.bn1 = IBN(planes, bn_norm)
         else:
-            self.bn1 = get_norm(bn_norm, planes, num_splits)
+            self.bn1 = get_norm(bn_norm, planes)
         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn2 = get_norm(bn_norm, planes, num_splits)
+        self.bn2 = get_norm(bn_norm, planes)
         self.relu = nn.ReLU(inplace=True)
         if with_se:
             self.se = SELayer(planes, reduction)
@@ -78,19 +78,19 @@ class BasicBlock(nn.Module):
 class Bottleneck(nn.Module):
     expansion = 4
 
-    def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, with_se=False,
+    def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False,
                  stride=1, downsample=None, reduction=16):
         super(Bottleneck, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
         if with_ibn:
-            self.bn1 = IBN(planes, bn_norm, num_splits)
+            self.bn1 = IBN(planes, bn_norm)
         else:
-            self.bn1 = get_norm(bn_norm, planes, num_splits)
+            self.bn1 = get_norm(bn_norm, planes)
         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                padding=1, bias=False)
-        self.bn2 = get_norm(bn_norm, planes, num_splits)
+        self.bn2 = get_norm(bn_norm, planes)
         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
-        self.bn3 = get_norm(bn_norm, planes * self.expansion, num_splits)
+        self.bn3 = get_norm(bn_norm, planes * self.expansion)
         self.relu = nn.ReLU(inplace=True)
         if with_se:
             self.se = SELayer(planes * self.expansion, reduction)
@@ -124,56 +124,56 @@ class Bottleneck(nn.Module):
 
 
 class ResNet(nn.Module):
-    def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block, layers, non_layers):
+    def __init__(self, last_stride, bn_norm, with_ibn, with_se, with_nl, block, layers, non_layers):
         self.inplanes = 64
         super().__init__()
         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                bias=False)
-        self.bn1 = get_norm(bn_norm, 64, num_splits)
+        self.bn1 = get_norm(bn_norm, 64)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
-        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn, with_se)
-        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn, with_se)
-        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn, with_se)
-        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_se=with_se)
+        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn, with_se)
+        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn, with_se)
+        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn, with_se)
+        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_se=with_se)
 
         self.random_init()
 
         # fmt: off
-        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
+        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm)
         else: self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
         # fmt: on
 
-    def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", num_splits=1, with_ibn=False, with_se=False):
+    def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", with_ibn=False, with_se=False):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
                 nn.Conv2d(self.inplanes, planes * block.expansion,
                           kernel_size=1, stride=stride, bias=False),
-                get_norm(bn_norm, planes * block.expansion, num_splits),
+                get_norm(bn_norm, planes * block.expansion),
             )
 
         layers = []
-        layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, with_se, stride, downsample))
+        layers.append(block(self.inplanes, planes, bn_norm, with_ibn, with_se, stride, downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, with_se))
+            layers.append(block(self.inplanes, planes, bn_norm, with_ibn, with_se))
 
         return nn.Sequential(*layers)
 
-    def _build_nonlocal(self, layers, non_layers, bn_norm, num_splits):
+    def _build_nonlocal(self, layers, non_layers, bn_norm):
         self.NL_1 = nn.ModuleList(
-            [Non_local(256, bn_norm, num_splits) for _ in range(non_layers[0])])
+            [Non_local(256, bn_norm) for _ in range(non_layers[0])])
         self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
         self.NL_2 = nn.ModuleList(
-            [Non_local(512, bn_norm, num_splits) for _ in range(non_layers[1])])
+            [Non_local(512, bn_norm) for _ in range(non_layers[1])])
         self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
         self.NL_3 = nn.ModuleList(
-            [Non_local(1024, bn_norm, num_splits) for _ in range(non_layers[2])])
+            [Non_local(1024, bn_norm) for _ in range(non_layers[2])])
         self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
         self.NL_4 = nn.ModuleList(
-            [Non_local(2048, bn_norm, num_splits) for _ in range(non_layers[3])])
+            [Non_local(2048, bn_norm) for _ in range(non_layers[3])])
         self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
 
     def forward(self, x):
@@ -298,7 +298,6 @@ def build_resnet_backbone(cfg):
     pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
     last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
     bn_norm       = cfg.MODEL.BACKBONE.NORM
-    num_splits    = cfg.MODEL.BACKBONE.NORM_SPLIT
     with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
     with_se       = cfg.MODEL.BACKBONE.WITH_SE
     with_nl       = cfg.MODEL.BACKBONE.WITH_NL
@@ -326,7 +325,7 @@ def build_resnet_backbone(cfg):
         '101x': Bottleneck
     }[depth]
 
-    model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block,
+    model = ResNet(last_stride, bn_norm, with_ibn, with_se, with_nl, block,
                    num_blocks_per_stage, nl_layers_per_stage)
     if pretrain:
         # Load pretrain path if specifically
```
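The `_build_nonlocal` index arithmetic repeated in these files is easy to misread: for each stage, `layers[k] - (i + 1)` marks the blocks, counted from the end of the stage, after which a `Non_local` module is applied. A worked example with the ResNet-50 settings used by these builders:

```python
# layers = blocks per stage, non_layers = Non_local modules per stage.
layers = [3, 4, 6, 3]
non_layers = [0, 2, 3, 0]

NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
print(NL_2_idx)  # [2, 3] -> after the last two blocks of stage 2

NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
print(NL_3_idx)  # [3, 4, 5] -> after the last three blocks of stage 3
```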
```diff
@@ -30,7 +30,7 @@ class Bottleneck(nn.Module):
     """
     expansion = 4
 
-    def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn, baseWidth, cardinality, stride=1,
+    def __init__(self, inplanes, planes, bn_norm, with_ibn, baseWidth, cardinality, stride=1,
                  downsample=None):
         """ Constructor
         Args:
@@ -46,13 +46,13 @@ class Bottleneck(nn.Module):
         C = cardinality
         self.conv1 = nn.Conv2d(inplanes, D * C, kernel_size=1, stride=1, padding=0, bias=False)
         if with_ibn:
-            self.bn1 = IBN(D * C, bn_norm, num_splits)
+            self.bn1 = IBN(D * C, bn_norm)
         else:
-            self.bn1 = get_norm(bn_norm, D * C, num_splits)
+            self.bn1 = get_norm(bn_norm, D * C)
         self.conv2 = nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
-        self.bn2 = get_norm(bn_norm, D * C, num_splits)
+        self.bn2 = get_norm(bn_norm, D * C)
         self.conv3 = nn.Conv2d(D * C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
-        self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
+        self.bn3 = get_norm(bn_norm, planes * 4)
         self.relu = nn.ReLU(inplace=True)
 
         self.downsample = downsample
@@ -86,7 +86,7 @@ class ResNeXt(nn.Module):
     https://arxiv.org/pdf/1611.05431.pdf
     """
 
-    def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_nl, block, layers, non_layers,
+    def __init__(self, last_stride, bn_norm, with_ibn, with_nl, block, layers, non_layers,
                  baseWidth=4, cardinality=32):
         """ Constructor
         Args:
@@ -102,22 +102,22 @@ class ResNeXt(nn.Module):
         self.output_size = 64
 
         self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
-        self.bn1 = get_norm(bn_norm, 64, num_splits)
+        self.bn1 = get_norm(bn_norm, 64)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn=with_ibn)
-        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn=with_ibn)
-        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn)
-        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_ibn=with_ibn)
+        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn=with_ibn)
+        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn=with_ibn)
+        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn=with_ibn)
+        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_ibn=with_ibn)
 
         self.random_init()
 
         # fmt: off
-        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
+        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm)
         else: self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
         # fmt: on
 
-    def _make_layer(self, block, planes, blocks, stride=1, bn_norm='BN', num_splits=1, with_ibn=False):
+    def _make_layer(self, block, planes, blocks, stride=1, bn_norm='BN', with_ibn=False):
         """ Stack n bottleneck modules where n is inferred from the depth of the network.
         Args:
             block: block type used to construct ResNext
@@ -131,33 +131,31 @@ class ResNeXt(nn.Module):
             downsample = nn.Sequential(
                 nn.Conv2d(self.inplanes, planes * block.expansion,
                           kernel_size=1, stride=stride, bias=False),
-                get_norm(bn_norm, planes * block.expansion, num_splits),
+                get_norm(bn_norm, planes * block.expansion),
             )
 
         layers = []
         if planes == 512:
             with_ibn = False
-        layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn,
+        layers.append(block(self.inplanes, planes, bn_norm, with_ibn,
                             self.baseWidth, self.cardinality, stride, downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
             layers.append(
-                block(self.inplanes, planes, bn_norm, num_splits, with_ibn, self.baseWidth, self.cardinality, 1, None))
+                block(self.inplanes, planes, bn_norm, with_ibn, self.baseWidth, self.cardinality, 1, None))
 
         return nn.Sequential(*layers)
 
-    def _build_nonlocal(self, layers, non_layers, bn_norm, num_splits):
+    def _build_nonlocal(self, layers, non_layers, bn_norm):
         self.NL_1 = nn.ModuleList(
-            [Non_local(256, bn_norm, num_splits) for _ in range(non_layers[0])])
+            [Non_local(256, bn_norm) for _ in range(non_layers[0])])
         self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
         self.NL_2 = nn.ModuleList(
-            [Non_local(512, bn_norm, num_splits) for _ in range(non_layers[1])])
+            [Non_local(512, bn_norm) for _ in range(non_layers[1])])
         self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
         self.NL_3 = nn.ModuleList(
-            [Non_local(1024, bn_norm, num_splits) for _ in range(non_layers[2])])
+            [Non_local(1024, bn_norm) for _ in range(non_layers[2])])
         self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
         self.NL_4 = nn.ModuleList(
-            [Non_local(2048, bn_norm, num_splits) for _ in range(non_layers[3])])
+            [Non_local(2048, bn_norm) for _ in range(non_layers[3])])
         self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
 
     def forward(self, x):
@@ -285,7 +283,6 @@ def build_resnext_backbone(cfg):
     pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
     last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
     bn_norm       = cfg.MODEL.BACKBONE.NORM
-    num_splits    = cfg.MODEL.BACKBONE.NORM_SPLIT
    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
     with_nl       = cfg.MODEL.BACKBONE.WITH_NL
     depth         = cfg.MODEL.BACKBONE.DEPTH
@@ -298,7 +295,7 @@ def build_resnext_backbone(cfg):
     nl_layers_per_stage = {
         '50x': [0, 2, 3, 0],
         '101x': [0, 2, 3, 0]}[depth]
-    model = ResNeXt(last_stride, bn_norm, num_splits, with_ibn, with_nl, Bottleneck,
+    model = ResNeXt(last_stride, bn_norm, with_ibn, with_nl, Bottleneck,
                     num_blocks_per_stage, nl_layers_per_stage)
     if pretrain:
         if pretrain_path:
```
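A final note on the ResNeXt hunks: the `D * C` channel count handed to `get_norm` comes from the usual ResNeXt width formula. The line defining `D` sits just outside the hunk shown, so it is reproduced here as an assumption:

```python
import math

planes, baseWidth, cardinality = 64, 4, 32
D = int(math.floor(planes * (baseWidth / 64)))  # width per group: 4 (assumed definition)
group_width = D * cardinality                   # 128 channels in the grouped 3x3 conv
print(D, group_width)  # 4 128
```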