diff --git a/fastreid/layers/batch_norm.py b/fastreid/layers/batch_norm.py
index dddecf6..e0e88e3 100644
--- a/fastreid/layers/batch_norm.py
+++ b/fastreid/layers/batch_norm.py
@@ -22,7 +22,7 @@ __all__ = [
 
 class BatchNorm(nn.BatchNorm2d):
     def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
-                 bias_init=0.0):
+                 bias_init=0.0, **kwargs):
         super().__init__(num_features, eps=eps, momentum=momentum)
         if weight_init is not None: nn.init.constant_(self.weight, weight_init)
         if bias_init is not None: nn.init.constant_(self.bias, bias_init)
@@ -34,20 +34,20 @@ class SyncBatchNorm(nn.SyncBatchNorm):
     def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                  bias_init=0.0):
         super().__init__(num_features, eps=eps, momentum=momentum)
-        if weight_init is not None: self.weight.data.fill_(weight_init)
-        if bias_init is not None: self.bias.data.fill_(bias_init)
+        if weight_init is not None: nn.init.constant_(self.weight, weight_init)
+        if bias_init is not None: nn.init.constant_(self.bias, bias_init)
         self.weight.requires_grad_(not weight_freeze)
         self.bias.requires_grad_(not bias_freeze)
 
 
 class IBN(nn.Module):
-    def __init__(self, planes, bn_norm, num_splits):
+    def __init__(self, planes, bn_norm, **kwargs):
         super(IBN, self).__init__()
         half1 = int(planes / 2)
         self.half = half1
         half2 = planes - half1
         self.IN = nn.InstanceNorm2d(half1, affine=True)
-        self.BN = get_norm(bn_norm, half2, num_splits)
+        self.BN = get_norm(bn_norm, half2, **kwargs)
 
     def forward(self, x):
         split = torch.split(x, self.half, 1)
@@ -100,8 +100,8 @@ class FrozenBatchNorm(BatchNorm):
 
     _version = 3
 
-    def __init__(self, num_features, eps=1e-5):
-        super().__init__(num_features, weight_freeze=True, bias_freeze=True)
+    def __init__(self, num_features, eps=1e-5, **kwargs):
+        super().__init__(num_features, weight_freeze=True, bias_freeze=True, **kwargs)
         self.num_features = num_features
         self.eps = eps
 
@@ -184,10 +184,14 @@ class FrozenBatchNorm(BatchNorm):
         return res
 
 
-def get_norm(norm, out_channels, num_splits=1, **kwargs):
+def get_norm(norm, out_channels, **kwargs):
     """
     Args:
-        norm (str or callable):
+        norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN;
+            or a callable that takes a channel number and returns
+            the normalization layer as a nn.Module
+        out_channels: number of channels for normalization layer
+
     Returns:
         nn.Module or None: the normalization layer
     """
@@ -195,10 +199,10 @@
         if len(norm) == 0: return None
         norm = {
-            "BN": BatchNorm(out_channels, **kwargs),
-            "GhostBN": GhostBatchNorm(out_channels, num_splits, **kwargs),
-            "FrozenBN": FrozenBatchNorm(out_channels),
-            "GN": nn.GroupNorm(32, out_channels),
-            "syncBN": SyncBatchNorm(out_channels, **kwargs),
+            "BN": BatchNorm,
+            "GhostBN": GhostBatchNorm,
+            "FrozenBN": FrozenBatchNorm,
+            "GN": lambda channels, **args: nn.GroupNorm(32, channels),
+            "syncBN": SyncBatchNorm,
         }[norm]
-    return norm
+    return norm(out_channels, **kwargs)
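Note on the get_norm rewrite above: the mapping now stores constructors and instantiates only the selected entry (the old dict eagerly built all five layers on every call just to pick one), and **kwargs are forwarded to whichever constructor is chosen. A minimal usage sketch; num_splits=2 is illustrative and assumes GhostBatchNorm, unchanged elsewhere in this file, accepts that keyword:

    from fastreid.layers.batch_norm import get_norm

    bn = get_norm("BN", 64)                        # -> BatchNorm(64)
    ghost = get_norm("GhostBN", 64, num_splits=2)  # kwargs reach GhostBatchNorm
    frozen = get_norm("FrozenBN", 64)              # FrozenBatchNorm now tolerates **kwargs
    gn = get_norm("GN", 64)                        # the lambda drops any extras

Because BatchNorm and FrozenBatchNorm now swallow **kwargs, call sites can pass norm-specific options without first checking which norm string is configured.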
""" - def __init__(self, in_channels, out_channels, bn_norm, num_splits): + def __init__(self, in_channels, out_channels, bn_norm): super(LightConv3x3, self).__init__() self.conv1 = nn.Conv2d( in_channels, out_channels, 1, stride=1, padding=0, bias=False @@ -155,7 +154,7 @@ class LightConv3x3(nn.Module): bias=False, groups=out_channels ) - self.bn = get_norm(bn_norm, out_channels, num_splits) + self.bn = get_norm(bn_norm, out_channels) self.relu = nn.ReLU(inplace=True) def forward(self, x): @@ -235,35 +234,34 @@ class OSBlock(nn.Module): in_channels, out_channels, bn_norm, - num_splits, IN=False, bottleneck_reduction=4, **kwargs ): super(OSBlock, self).__init__() mid_channels = out_channels // bottleneck_reduction - self.conv1 = Conv1x1(in_channels, mid_channels, bn_norm, num_splits) - self.conv2a = LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits) + self.conv1 = Conv1x1(in_channels, mid_channels, bn_norm) + self.conv2a = LightConv3x3(mid_channels, mid_channels, bn_norm) self.conv2b = nn.Sequential( - LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits), - LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits), + LightConv3x3(mid_channels, mid_channels, bn_norm), + LightConv3x3(mid_channels, mid_channels, bn_norm), ) self.conv2c = nn.Sequential( - LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits), - LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits), - LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits), + LightConv3x3(mid_channels, mid_channels, bn_norm), + LightConv3x3(mid_channels, mid_channels, bn_norm), + LightConv3x3(mid_channels, mid_channels, bn_norm), ) self.conv2d = nn.Sequential( - LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits), - LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits), - LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits), - LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits), + LightConv3x3(mid_channels, mid_channels, bn_norm), + LightConv3x3(mid_channels, mid_channels, bn_norm), + LightConv3x3(mid_channels, mid_channels, bn_norm), + LightConv3x3(mid_channels, mid_channels, bn_norm), ) self.gate = ChannelGate(mid_channels) - self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn_norm, num_splits) + self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn_norm) self.downsample = None if in_channels != out_channels: - self.downsample = Conv1x1Linear(in_channels, out_channels, bn_norm, num_splits) + self.downsample = Conv1x1Linear(in_channels, out_channels, bn_norm) self.IN = None if IN: self.IN = nn.InstanceNorm2d(out_channels, affine=True) self.relu = nn.ReLU(True) @@ -303,7 +301,6 @@ class OSNet(nn.Module): layers, channels, bn_norm, - num_splits, IN=False, **kwargs ): @@ -313,7 +310,7 @@ class OSNet(nn.Module): assert num_blocks == len(channels) - 1 # convolutional backbone - self.conv1 = ConvLayer(3, channels[0], 7, bn_norm, num_splits, stride=2, padding=3, IN=IN) + self.conv1 = ConvLayer(3, channels[0], 7, bn_norm, stride=2, padding=3, IN=IN) self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) self.conv2 = self._make_layer( blocks[0], @@ -321,7 +318,6 @@ class OSNet(nn.Module): channels[0], channels[1], bn_norm, - num_splits, reduce_spatial_size=True, IN=IN ) @@ -331,7 +327,6 @@ class OSNet(nn.Module): channels[1], channels[2], bn_norm, - num_splits, reduce_spatial_size=True ) self.conv4 = self._make_layer( @@ -340,10 +335,9 @@ class OSNet(nn.Module): channels[2], channels[3], bn_norm, - num_splits, reduce_spatial_size=False ) - self.conv5 = 
Conv1x1(channels[3], channels[3], bn_norm, num_splits) + self.conv5 = Conv1x1(channels[3], channels[3], bn_norm) self._init_params() @@ -354,20 +348,19 @@ class OSNet(nn.Module): in_channels, out_channels, bn_norm, - num_splits, reduce_spatial_size, IN=False ): layers = [] - layers.append(block(in_channels, out_channels, bn_norm, num_splits, IN=IN)) + layers.append(block(in_channels, out_channels, bn_norm, IN=IN)) for i in range(1, layer): - layers.append(block(out_channels, out_channels, bn_norm, num_splits, IN=IN)) + layers.append(block(out_channels, out_channels, bn_norm, IN=IN)) if reduce_spatial_size: layers.append( nn.Sequential( - Conv1x1(out_channels, out_channels, bn_norm, num_splits), + Conv1x1(out_channels, out_channels, bn_norm), nn.AvgPool2d(2, stride=2), ) ) @@ -498,12 +491,11 @@ def build_osnet_backbone(cfg): """ # fmt: off - pretrain = cfg.MODEL.BACKBONE.PRETRAIN + pretrain = cfg.MODEL.BACKBONE.PRETRAIN pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH - with_ibn = cfg.MODEL.BACKBONE.WITH_IBN - bn_norm = cfg.MODEL.BACKBONE.NORM - num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT - depth = cfg.MODEL.BACKBONE.DEPTH + with_ibn = cfg.MODEL.BACKBONE.WITH_IBN + bn_norm = cfg.MODEL.BACKBONE.NORM + depth = cfg.MODEL.BACKBONE.DEPTH # fmt: on num_blocks_per_stage = [2, 2, 2] @@ -513,7 +505,7 @@ def build_osnet_backbone(cfg): "x0_5": [32, 128, 192, 256], "x0_25": [16, 64, 96, 128]}[depth] model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage, - bn_norm, num_splits, IN=with_ibn) + bn_norm, IN=with_ibn) if pretrain: # Load pretrain path if specifically diff --git a/fastreid/modeling/backbones/resnest.py b/fastreid/modeling/backbones/resnest.py index 5495108..3ddc39c 100644 --- a/fastreid/modeling/backbones/resnest.py +++ b/fastreid/modeling/backbones/resnest.py @@ -46,7 +46,7 @@ class Bottleneck(nn.Module): # pylint: disable=unused-argument expansion = 4 - def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, stride=1, downsample=None, + def __init__(self, inplanes, planes, bn_norm, with_ibn=False, stride=1, downsample=None, radix=1, cardinality=1, bottleneck_width=64, avd=False, avd_first=False, dilation=1, is_first=False, rectified_conv=False, rectify_avg=False, @@ -55,9 +55,9 @@ class Bottleneck(nn.Module): group_width = int(planes * (bottleneck_width / 64.)) * cardinality self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) if with_ibn: - self.bn1 = IBN(group_width, bn_norm, num_splits) + self.bn1 = IBN(group_width, bn_norm) else: - self.bn1 = get_norm(bn_norm, group_width, num_splits) + self.bn1 = get_norm(bn_norm, group_width) self.dropblock_prob = dropblock_prob self.radix = radix self.avd = avd and (stride > 1 or is_first) @@ -74,7 +74,7 @@ class Bottleneck(nn.Module): dilation=dilation, groups=cardinality, bias=False, radix=radix, rectify=rectified_conv, rectify_avg=rectify_avg, - norm_layer=bn_norm, num_splits=num_splits, + norm_layer=bn_norm, dropblock_prob=dropblock_prob) elif rectified_conv: from rfconv import RFConv2d @@ -83,17 +83,17 @@ class Bottleneck(nn.Module): padding=dilation, dilation=dilation, groups=cardinality, bias=False, average_mode=rectify_avg) - self.bn2 = get_norm(bn_norm, group_width, num_splits) + self.bn2 = get_norm(bn_norm, group_width) else: self.conv2 = nn.Conv2d( group_width, group_width, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, groups=cardinality, bias=False) - self.bn2 = get_norm(bn_norm, group_width, num_splits) + self.bn2 = get_norm(bn_norm, group_width) 
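With num_splits gone from every OSNet building block, the builder threads only bn_norm; any norm-specific extras would ride through get_norm's **kwargs instead. A sketch of constructing the x1_0 variant directly, mirroring the build_osnet_backbone call above (stage widths copied from its table):

    from fastreid.modeling.backbones.osnet import OSNet, OSBlock

    model = OSNet([OSBlock, OSBlock, OSBlock],  # block type per stage
                  [2, 2, 2],                    # num_blocks_per_stage
                  [64, 256, 384, 512],          # "x1_0" num_channels_per_stage
                  "BN",                         # bn_norm; no num_splits argument
                  IN=False)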
diff --git a/fastreid/modeling/backbones/resnest.py b/fastreid/modeling/backbones/resnest.py
index 5495108..3ddc39c 100644
--- a/fastreid/modeling/backbones/resnest.py
+++ b/fastreid/modeling/backbones/resnest.py
@@ -46,7 +46,7 @@ class Bottleneck(nn.Module):
     # pylint: disable=unused-argument
     expansion = 4
 
-    def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, stride=1, downsample=None,
+    def __init__(self, inplanes, planes, bn_norm, with_ibn=False, stride=1, downsample=None,
                  radix=1, cardinality=1, bottleneck_width=64,
                  avd=False, avd_first=False, dilation=1, is_first=False,
                  rectified_conv=False, rectify_avg=False,
@@ -55,9 +55,9 @@ class Bottleneck(nn.Module):
         group_width = int(planes * (bottleneck_width / 64.)) * cardinality
         self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
         if with_ibn:
-            self.bn1 = IBN(group_width, bn_norm, num_splits)
+            self.bn1 = IBN(group_width, bn_norm)
         else:
-            self.bn1 = get_norm(bn_norm, group_width, num_splits)
+            self.bn1 = get_norm(bn_norm, group_width)
         self.dropblock_prob = dropblock_prob
         self.radix = radix
         self.avd = avd and (stride > 1 or is_first)
@@ -74,7 +74,7 @@ class Bottleneck(nn.Module):
                 dilation=dilation, groups=cardinality, bias=False,
                 radix=radix, rectify=rectified_conv,
                 rectify_avg=rectify_avg,
-                norm_layer=bn_norm, num_splits=num_splits,
+                norm_layer=bn_norm,
                 dropblock_prob=dropblock_prob)
         elif rectified_conv:
             from rfconv import RFConv2d
@@ -83,17 +83,17 @@ class Bottleneck(nn.Module):
                 padding=dilation, dilation=dilation,
                 groups=cardinality, bias=False,
                 average_mode=rectify_avg)
-            self.bn2 = get_norm(bn_norm, group_width, num_splits)
+            self.bn2 = get_norm(bn_norm, group_width)
         else:
             self.conv2 = nn.Conv2d(
                 group_width, group_width, kernel_size=3, stride=stride,
                 padding=dilation, dilation=dilation,
                 groups=cardinality, bias=False)
-            self.bn2 = get_norm(bn_norm, group_width, num_splits)
+            self.bn2 = get_norm(bn_norm, group_width)
 
         self.conv3 = nn.Conv2d(
             group_width, planes * 4, kernel_size=1, bias=False)
-        self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
+        self.bn3 = get_norm(bn_norm, planes * 4)
 
         if last_gamma:
             from torch.nn.init import zeros_
@@ -161,7 +161,7 @@ class ResNest(nn.Module):
     """
 
     # pylint: disable=unused-variable
-    def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_nl, block, layers, non_layers, radix=1,
+    def __init__(self, last_stride, bn_norm, with_ibn, with_nl, block, layers, non_layers, radix=1,
                  groups=1, bottleneck_width=64,
                  dilated=False, dilation=1,
@@ -193,35 +193,35 @@ class ResNest(nn.Module):
         if deep_stem:
             self.conv1 = nn.Sequential(
                 conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
-                get_norm(bn_norm, stem_width, num_splits),
+                get_norm(bn_norm, stem_width),
                 nn.ReLU(inplace=True),
                 conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
-                get_norm(bn_norm, stem_width, num_splits),
+                get_norm(bn_norm, stem_width),
                 nn.ReLU(inplace=True),
                 conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
             )
         else:
             self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3, bias=False, **conv_kwargs)
-        self.bn1 = get_norm(bn_norm, self.inplanes, num_splits)
+        self.bn1 = get_norm(bn_norm, self.inplanes)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn=with_ibn, is_first=False)
-        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn=with_ibn)
+        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn=with_ibn, is_first=False)
+        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn=with_ibn)
         if dilated or dilation == 4:
-            self.layer3 = self._make_layer(block, 256, layers[2], 1, bn_norm, num_splits, with_ibn=with_ibn,
+            self.layer3 = self._make_layer(block, 256, layers[2], 1, bn_norm, with_ibn=with_ibn,
                                            dilation=2, dropblock_prob=dropblock_prob)
-            self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, num_splits, with_ibn=with_ibn,
+            self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, with_ibn=with_ibn,
                                            dilation=4, dropblock_prob=dropblock_prob)
         elif dilation == 2:
-            self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn,
+            self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn=with_ibn,
                                            dilation=1, dropblock_prob=dropblock_prob)
-            self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, num_splits, with_ibn=with_ibn,
+            self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, with_ibn=with_ibn,
                                            dilation=2, dropblock_prob=dropblock_prob)
         else:
-            self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn,
+            self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn=with_ibn,
                                            dropblock_prob=dropblock_prob)
-            self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_ibn=with_ibn,
+            self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_ibn=with_ibn,
                                            dropblock_prob=dropblock_prob)
 
         for m in self.modules():
@@ -232,12 +232,12 @@ class ResNest(nn.Module):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
 
-        if with_nl:
-            self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
-        else:
-            self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
+        # fmt: off
+        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm)
+        else:       self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
+        # fmt: on
 
-    def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", num_splits=1, with_ibn=False,
+    def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", with_ibn=False,
                     dilation=1, dropblock_prob=0.0, is_first=True):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
@@ -254,12 +254,12 @@ class ResNest(nn.Module):
             else:
                 down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
                                              kernel_size=1, stride=stride, bias=False))
-            down_layers.append(get_norm(bn_norm, planes * block.expansion, num_splits))
+            down_layers.append(get_norm(bn_norm, planes * block.expansion))
             downsample = nn.Sequential(*down_layers)
 
         layers = []
         if dilation == 1 or dilation == 2:
-            layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, stride, downsample=downsample,
+            layers.append(block(self.inplanes, planes, bn_norm, with_ibn, stride, downsample=downsample,
                                 radix=self.radix, cardinality=self.cardinality,
                                 bottleneck_width=self.bottleneck_width,
                                 avd=self.avd, avd_first=self.avd_first,
@@ -268,7 +268,7 @@ class ResNest(nn.Module):
                                 dropblock_prob=dropblock_prob,
                                 last_gamma=self.last_gamma))
         elif dilation == 4:
-            layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, stride, downsample=downsample,
+            layers.append(block(self.inplanes, planes, bn_norm, with_ibn, stride, downsample=downsample,
                                 radix=self.radix, cardinality=self.cardinality,
                                 bottleneck_width=self.bottleneck_width,
                                 avd=self.avd, avd_first=self.avd_first,
@@ -281,7 +281,7 @@ class ResNest(nn.Module):
 
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn,
+            layers.append(block(self.inplanes, planes, bn_norm, with_ibn,
                                 radix=self.radix, cardinality=self.cardinality,
                                 bottleneck_width=self.bottleneck_width,
                                 avd=self.avd, avd_first=self.avd_first,
@@ -292,18 +292,18 @@ class ResNest(nn.Module):
 
         return nn.Sequential(*layers)
 
-    def _build_nonlocal(self, layers, non_layers, bn_norm, num_splits):
+    def _build_nonlocal(self, layers, non_layers, bn_norm):
         self.NL_1 = nn.ModuleList(
-            [Non_local(256, bn_norm, num_splits) for _ in range(non_layers[0])])
+            [Non_local(256, bn_norm) for _ in range(non_layers[0])])
         self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
         self.NL_2 = nn.ModuleList(
-            [Non_local(512, bn_norm, num_splits) for _ in range(non_layers[1])])
+            [Non_local(512, bn_norm) for _ in range(non_layers[1])])
         self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
         self.NL_3 = nn.ModuleList(
-            [Non_local(1024, bn_norm, num_splits) for _ in range(non_layers[2])])
+            [Non_local(1024, bn_norm) for _ in range(non_layers[2])])
         self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
         self.NL_4 = nn.ModuleList(
-            [Non_local(2048, bn_norm, num_splits) for _ in range(non_layers[3])])
+            [Non_local(2048, bn_norm) for _ in range(non_layers[3])])
         self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
 
     def forward(self, x):
@@ -364,21 +364,38 @@ def build_resnest_backbone(cfg):
     """
 
     # fmt: off
-    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
+    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
     pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
-    last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
-    bn_norm = cfg.MODEL.BACKBONE.NORM
-    num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
-    with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
-    with_se = cfg.MODEL.BACKBONE.WITH_SE
-    with_nl = cfg.MODEL.BACKBONE.WITH_NL
-    depth = cfg.MODEL.BACKBONE.DEPTH
+    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
+    bn_norm       = cfg.MODEL.BACKBONE.NORM
+    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
+    with_se       = cfg.MODEL.BACKBONE.WITH_SE
+    with_nl       = cfg.MODEL.BACKBONE.WITH_NL
+    depth         = cfg.MODEL.BACKBONE.DEPTH
+    # fmt: on
 
-    num_blocks_per_stage = {"50x": [3, 4, 6, 3], "101x": [3, 4, 23, 3], "200x": [3, 24, 36, 3],
-                            "269x": [3, 30, 48, 8]}[depth]
-    nl_layers_per_stage = {"50x": [0, 2, 3, 0], "101x": [0, 2, 3, 0], "200x": [0, 2, 3, 0], "269x": [0, 2, 3, 0]}[depth]
-    stem_width = {"50x": 32, "101x": 64, "200x": 64, "269x": 64}[depth]
-    model = ResNest(last_stride, bn_norm, num_splits, with_ibn, with_nl, Bottleneck, num_blocks_per_stage,
+    num_blocks_per_stage = {
+        "50x": [3, 4, 6, 3],
+        "101x": [3, 4, 23, 3],
+        "200x": [3, 24, 36, 3],
+        "269x": [3, 30, 48, 8],
+    }[depth]
+
+    nl_layers_per_stage = {
+        "50x": [0, 2, 3, 0],
+        "101x": [0, 2, 3, 0],
+        "200x": [0, 2, 3, 0],
+        "269x": [0, 2, 3, 0],
+    }[depth]
+
+    stem_width = {
+        "50x": 32,
+        "101x": 64,
+        "200x": 64,
+        "269x": 64,
+    }[depth]
+
+    model = ResNest(last_stride, bn_norm, with_ibn, with_nl, Bottleneck, num_blocks_per_stage,
                     nl_layers_per_stage, radix=2, groups=1, bottleneck_width=64,
                     deep_stem=True, stem_width=stem_width, avg_down=True,
                     avd=True, avd_first=False)
diff --git a/fastreid/modeling/backbones/resnet.py b/fastreid/modeling/backbones/resnet.py
index bf938df..683dbb9 100644
--- a/fastreid/modeling/backbones/resnet.py
+++ b/fastreid/modeling/backbones/resnet.py
@@ -38,16 +38,16 @@
 class BasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, with_se=False,
+    def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False,
                  stride=1, downsample=None, reduction=16):
         super(BasicBlock, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
         if with_ibn:
-            self.bn1 = IBN(planes, bn_norm, num_splits)
+            self.bn1 = IBN(planes, bn_norm)
         else:
-            self.bn1 = get_norm(bn_norm, planes, num_splits)
+            self.bn1 = get_norm(bn_norm, planes)
         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn2 = get_norm(bn_norm, planes, num_splits)
+        self.bn2 = get_norm(bn_norm, planes)
         self.relu = nn.ReLU(inplace=True)
         if with_se:
             self.se = SELayer(planes, reduction)
@@ -78,19 +78,19 @@ class BasicBlock(nn.Module):
 class Bottleneck(nn.Module):
     expansion = 4
 
-    def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, with_se=False,
+    def __init__(self, inplanes, planes, bn_norm, with_ibn=False, with_se=False,
                  stride=1, downsample=None, reduction=16):
         super(Bottleneck, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
         if with_ibn:
-            self.bn1 = IBN(planes, bn_norm, num_splits)
+            self.bn1 = IBN(planes, bn_norm)
         else:
-            self.bn1 = get_norm(bn_norm, planes, num_splits)
+            self.bn1 = get_norm(bn_norm, planes)
         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                padding=1, bias=False)
-        self.bn2 = get_norm(bn_norm, planes, num_splits)
+        self.bn2 = get_norm(bn_norm, planes)
         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
-        self.bn3 = get_norm(bn_norm, planes * self.expansion, num_splits)
+        self.bn3 = get_norm(bn_norm, planes * self.expansion)
         self.relu = nn.ReLU(inplace=True)
         if with_se:
             self.se = SELayer(planes * self.expansion, reduction)
@@ -124,56 +124,56 @@ class Bottleneck(nn.Module):
 
 
 class ResNet(nn.Module):
-    def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block, layers, non_layers):
+    def __init__(self, last_stride, bn_norm, with_ibn, with_se, with_nl, block, layers, non_layers):
         self.inplanes = 64
         super().__init__()
         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                bias=False)
-        self.bn1 = get_norm(bn_norm, 64, num_splits)
+        self.bn1 = get_norm(bn_norm, 64)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
-        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn, with_se)
-        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn, with_se)
-        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn, with_se)
-        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_se=with_se)
+        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn, with_se)
+        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn, with_se)
+        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn, with_se)
+        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_se=with_se)
 
         self.random_init()
 
         # fmt: off
-        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
+        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm)
         else:       self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
         # fmt: on
 
-    def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", num_splits=1, with_ibn=False, with_se=False):
+    def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", with_ibn=False, with_se=False):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
                 nn.Conv2d(self.inplanes, planes * block.expansion,
                           kernel_size=1, stride=stride, bias=False),
-                get_norm(bn_norm, planes * block.expansion, num_splits),
+                get_norm(bn_norm, planes * block.expansion),
             )
 
         layers = []
-        layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, with_se, stride, downsample))
+        layers.append(block(self.inplanes, planes, bn_norm, with_ibn, with_se, stride, downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, with_se))
+            layers.append(block(self.inplanes, planes, bn_norm, with_ibn, with_se))
 
         return nn.Sequential(*layers)
 
-    def _build_nonlocal(self, layers, non_layers, bn_norm, num_splits):
+    def _build_nonlocal(self, layers, non_layers, bn_norm):
         self.NL_1 = nn.ModuleList(
-            [Non_local(256, bn_norm, num_splits) for _ in range(non_layers[0])])
+            [Non_local(256, bn_norm) for _ in range(non_layers[0])])
         self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
         self.NL_2 = nn.ModuleList(
-            [Non_local(512, bn_norm, num_splits) for _ in range(non_layers[1])])
+            [Non_local(512, bn_norm) for _ in range(non_layers[1])])
         self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
         self.NL_3 = nn.ModuleList(
-            [Non_local(1024, bn_norm, num_splits) for _ in range(non_layers[2])])
+            [Non_local(1024, bn_norm) for _ in range(non_layers[2])])
         self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
         self.NL_4 = nn.ModuleList(
-            [Non_local(2048, bn_norm, num_splits) for _ in range(non_layers[3])])
+            [Non_local(2048, bn_norm) for _ in range(non_layers[3])])
         self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
 
     def forward(self, x):
@@ -298,7 +298,6 @@ def build_resnet_backbone(cfg):
     pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
     last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
     bn_norm       = cfg.MODEL.BACKBONE.NORM
-    num_splits    = cfg.MODEL.BACKBONE.NORM_SPLIT
     with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
     with_se       = cfg.MODEL.BACKBONE.WITH_SE
     with_nl       = cfg.MODEL.BACKBONE.WITH_NL
@@ -326,7 +325,7 @@ def build_resnet_backbone(cfg):
         '101x': Bottleneck
     }[depth]
 
-    model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block,
+    model = ResNet(last_stride, bn_norm, with_ibn, with_se, with_nl, block,
                    num_blocks_per_stage, nl_layers_per_stage)
     if pretrain:
         # Load pretrain path if specifically
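The _build_nonlocal index bookkeeping is untouched by this diff but easy to misread: each NL_k_idx records the block positions after which a Non_local module fires, placed so the non-local blocks sit at the end of each stage. A quick check of the arithmetic for ResNet-50 (layers = [3, 4, 6, 3], nl_layers_per_stage = [0, 2, 3, 0]):

    layers = [3, 4, 6, 3]      # residual blocks per stage (ResNet-50)
    non_layers = [0, 2, 3, 0]  # Non_local modules per stage

    for stage, (n_blocks, n_nl) in enumerate(zip(layers, non_layers), start=1):
        idx = sorted([n_blocks - (i + 1) for i in range(n_nl)])
        print(f"layer{stage}: Non_local after block indices {idx}")
    # layer1: Non_local after block indices []
    # layer2: Non_local after block indices [2, 3]
    # layer3: Non_local after block indices [3, 4, 5]
    # layer4: Non_local after block indices []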
diff --git a/fastreid/modeling/backbones/resnext.py b/fastreid/modeling/backbones/resnext.py
index 22593ad..6d61125 100644
--- a/fastreid/modeling/backbones/resnext.py
+++ b/fastreid/modeling/backbones/resnext.py
@@ -30,7 +30,7 @@ class Bottleneck(nn.Module):
     """
     expansion = 4
 
-    def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn, baseWidth, cardinality, stride=1,
+    def __init__(self, inplanes, planes, bn_norm, with_ibn, baseWidth, cardinality, stride=1,
                  downsample=None):
         """ Constructor
         Args:
@@ -46,13 +46,13 @@ class Bottleneck(nn.Module):
         C = cardinality
         self.conv1 = nn.Conv2d(inplanes, D * C, kernel_size=1, stride=1, padding=0, bias=False)
         if with_ibn:
-            self.bn1 = IBN(D * C, bn_norm, num_splits)
+            self.bn1 = IBN(D * C, bn_norm)
         else:
-            self.bn1 = get_norm(bn_norm, D * C, num_splits)
+            self.bn1 = get_norm(bn_norm, D * C)
         self.conv2 = nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
-        self.bn2 = get_norm(bn_norm, D * C, num_splits)
+        self.bn2 = get_norm(bn_norm, D * C)
         self.conv3 = nn.Conv2d(D * C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
-        self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
+        self.bn3 = get_norm(bn_norm, planes * 4)
         self.relu = nn.ReLU(inplace=True)
 
         self.downsample = downsample
@@ -86,7 +86,7 @@ class ResNeXt(nn.Module):
     https://arxiv.org/pdf/1611.05431.pdf
     """
 
-    def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_nl, block, layers, non_layers,
+    def __init__(self, last_stride, bn_norm, with_ibn, with_nl, block, layers, non_layers,
                  baseWidth=4, cardinality=32):
         """ Constructor
         Args:
@@ -102,22 +102,22 @@ class ResNeXt(nn.Module):
         self.output_size = 64
 
         self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
-        self.bn1 = get_norm(bn_norm, 64, num_splits)
+        self.bn1 = get_norm(bn_norm, 64)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn=with_ibn)
-        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn=with_ibn)
-        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn)
-        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_ibn=with_ibn)
+        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn=with_ibn)
+        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn=with_ibn)
+        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn=with_ibn)
+        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_ibn=with_ibn)
 
         self.random_init()
 
         # fmt: off
-        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
+        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm)
         else:       self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
         # fmt: on
 
-    def _make_layer(self, block, planes, blocks, stride=1, bn_norm='BN', num_splits=1, with_ibn=False):
+    def _make_layer(self, block, planes, blocks, stride=1, bn_norm='BN', with_ibn=False):
         """ Stack n bottleneck modules where n is inferred from the depth of the network.
         Args:
             block: block type used to construct ResNext
@@ -131,33 +131,31 @@ class ResNeXt(nn.Module):
             downsample = nn.Sequential(
                 nn.Conv2d(self.inplanes, planes * block.expansion,
                           kernel_size=1, stride=stride, bias=False),
-                get_norm(bn_norm, planes * block.expansion, num_splits),
+                get_norm(bn_norm, planes * block.expansion),
             )
 
         layers = []
-        if planes == 512:
-            with_ibn = False
-        layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn,
+        layers.append(block(self.inplanes, planes, bn_norm, with_ibn,
                             self.baseWidth, self.cardinality, stride, downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
            layers.append(
-                block(self.inplanes, planes, bn_norm, num_splits, with_ibn, self.baseWidth, self.cardinality, 1, None))
+                block(self.inplanes, planes, bn_norm, with_ibn, self.baseWidth, self.cardinality, 1, None))
 
         return nn.Sequential(*layers)
 
-    def _build_nonlocal(self, layers, non_layers, bn_norm, num_splits):
+    def _build_nonlocal(self, layers, non_layers, bn_norm):
         self.NL_1 = nn.ModuleList(
-            [Non_local(256, bn_norm, num_splits) for _ in range(non_layers[0])])
+            [Non_local(256, bn_norm) for _ in range(non_layers[0])])
         self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
         self.NL_2 = nn.ModuleList(
-            [Non_local(512, bn_norm, num_splits) for _ in range(non_layers[1])])
+            [Non_local(512, bn_norm) for _ in range(non_layers[1])])
         self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
         self.NL_3 = nn.ModuleList(
-            [Non_local(1024, bn_norm, num_splits) for _ in range(non_layers[2])])
+            [Non_local(1024, bn_norm) for _ in range(non_layers[2])])
         self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
         self.NL_4 = nn.ModuleList(
-            [Non_local(2048, bn_norm, num_splits) for _ in range(non_layers[3])])
+            [Non_local(2048, bn_norm) for _ in range(non_layers[3])])
         self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
 
     def forward(self, x):
@@ -285,7 +283,6 @@ def build_resnext_backbone(cfg):
     pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
     last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
     bn_norm       = cfg.MODEL.BACKBONE.NORM
-    num_splits    = cfg.MODEL.BACKBONE.NORM_SPLIT
     with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
     with_nl       = cfg.MODEL.BACKBONE.WITH_NL
     depth         = cfg.MODEL.BACKBONE.DEPTH
@@ -298,7 +295,7 @@ def build_resnext_backbone(cfg):
     nl_layers_per_stage = {
         '50x': [0, 2, 3, 0],
         '101x': [0, 2, 3, 0]}[depth]
-    model = ResNeXt(last_stride, bn_norm, num_splits, with_ibn, with_nl, Bottleneck,
+    model = ResNeXt(last_stride, bn_norm, with_ibn, with_nl, Bottleneck,
                     num_blocks_per_stage, nl_layers_per_stage)
     if pretrain:
         if pretrain_path:
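With every builder above no longer reading MODEL.BACKBONE.NORM_SPLIT, the config surface shrinks accordingly; a GhostBN split count would now travel through get_norm's **kwargs at whatever site constructs the norm layer. A sketch of the unchanged entry point, assuming fastreid's standard get_cfg defaults still define these keys:

    from fastreid.config import get_cfg
    from fastreid.modeling.backbones import build_resnet_backbone

    cfg = get_cfg()
    cfg.MODEL.BACKBONE.NORM = "BN"
    cfg.MODEL.BACKBONE.DEPTH = "50x"
    backbone = build_resnet_backbone(cfg)  # NORM_SPLIT is never consulted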