feat(layers/norm): add ghost batchnorm

add a get_norm function to easily switch normalization between batchnorm, ghost BN and group norm
pull/49/head
liaoxingyu 2020-05-01 09:02:46 +08:00
parent 329764bb60
commit a2dcd7b4ab
25 changed files with 242 additions and 374 deletions

View File

@ -31,6 +31,10 @@ _C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
_C.MODEL.BACKBONE.DEPTH = 50
_C.MODEL.BACKBONE.LAST_STRIDE = 1
# Normalization method for the convolution layers.
_C.MODEL.BACKBONE.NORM = "BN"
# Mini-batch split of Ghost BN
_C.MODEL.BACKBONE.NORM_SPLIT = 1
# If use IBN block in backbone
_C.MODEL.BACKBONE.WITH_IBN = False
# If use SE block in backbone
@ -48,17 +52,23 @@ _C.MODEL.BACKBONE.PRETRAIN_PATH = ''
_C.MODEL.HEADS = CN()
_C.MODEL.HEADS.NAME = "BNneckHead"
# Normalization method for the convolution layers.
_C.MODEL.HEADS.NORM = "BN"
# Mini-batch split of Ghost BN
_C.MODEL.HEADS.NORM_SPLIT = 1
# Number of identity
_C.MODEL.HEADS.NUM_CLASSES = 751
# Input feature dimension
_C.MODEL.HEADS.IN_FEAT = 2048
# Reduction dimension in head
_C.MODEL.HEADS.REDUCTION_DIM = 512
# Triplet feature using feature before(after) bnneck
_C.MODEL.HEADS.NECK_FEAT = "before" # options: before, after
# Pooling layer type
_C.MODEL.HEADS.POOL_LAYER = 'avgpool'
_C.MODEL.HEADS.POOL_LAYER = "avgpool"
# Classification layer type
_C.MODEL.HEADS.CLS_LAYER = 'linear' # 'arcface' or 'circle'
_C.MODEL.HEADS.CLS_LAYER = "linear" # "arcface" or "circle"
# Margin and Scale for margin-based classification layer
_C.MODEL.HEADS.MARGIN = 0.15
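
For reference, a hedged sketch of how these new options could be set from a training script. It assumes fastreid exposes a detectron2-style get_cfg() entry point; the split value of 4 is purely illustrative.

from fastreid.config import get_cfg  # assumed config factory

cfg = get_cfg()
cfg.MODEL.BACKBONE.NORM = "GhostBN"   # normalization type for backbone conv layers
cfg.MODEL.BACKBONE.NORM_SPLIT = 4     # each mini-batch is normalized as 4 virtual batches
cfg.MODEL.HEADS.NORM = "GhostBN"
cfg.MODEL.HEADS.NORM_SPLIT = 4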

View File

@ -343,8 +343,8 @@ class DefaultTrainer(SimpleTrainer):
Overwrite it if you'd like a different model.
"""
model = build_model(cfg)
# logger = logging.getLogger(__name__)
# logger.info("Model:\n{}".format(model))
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
return model
@classmethod
@ -412,7 +412,7 @@ class DefaultTrainer(SimpleTrainer):
results = OrderedDict()
for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
logger.info(f'prepare test set {dataset_name}')
logger.info(f'prepare test set')
data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
# When evaluators are passed in as arguments,
# implicitly assume that evaluators can be created before data_loader.

View File

@ -30,7 +30,6 @@ __all__ = [
"AutogradProfiler",
"EvalHook",
"PreciseBN",
"LRFinder",
"FreezeLayer",
]

View File

@ -10,7 +10,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['Mish', 'Swish', 'MemoryEfficientSwish', 'GELU']
__all__ = [
'Mish',
'Swish',
'MemoryEfficientSwish',
'GELU']
class Mish(nn.Module):

View File

@ -1,65 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import math
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from . import *
from ..modeling.model_utils import weights_init_kaiming
class AdaCos(nn.Module):
def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
super().__init__()
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.pool_layer = nn.Sequential(
pool_layer,
Flatten()
)
# bnneck
self.bnneck = NoBiasBatchNorm1d(in_feat)
self.bnneck.apply(weights_init_kaiming)
# classifier
self._s = math.sqrt(2) * math.log(self._num_classes - 1)
self._m = 0.50
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
def forward(self, features, targets=None):
global_feat = self.pool_layer(features)
bn_feat = self.bnneck(global_feat)
if not self.training:
return bn_feat
# normalize features
x = F.normalize(bn_feat)
# normalize weights
weight = F.normalize(self.weight)
# dot product
logits = F.linear(x, weight)
# feature re-scale
theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7))
one_hot = torch.zeros_like(logits)
one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
with torch.no_grad():
B_avg = torch.where(one_hot < 1, torch.exp(self._s * logits), torch.zeros_like(logits))
B_avg = torch.sum(B_avg) / x.size(0)
# print(B_avg)
theta_med = torch.median(theta[one_hot == 1])
self.s = torch.log(B_avg) / torch.cos(torch.min(math.pi / 4 * torch.ones_like(theta_med), theta_med))
pred_class_logits = self.s * logits
return pred_class_logits, global_feat

View File

@ -11,7 +11,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from ..modeling.losses.loss_utils import one_hot
from fastreid.utils.one_hot import one_hot
class Arcface(nn.Module):

View File

@ -11,7 +11,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from ..modeling.losses.loss_utils import one_hot
from fastreid.utils.one_hot import one_hot
class Circle(nn.Module):

View File

@ -3,10 +3,11 @@
import torch
from torch import nn
from .norm import get_norm
class Non_local(nn.Module):
def __init__(self, in_channels, reduc_ratio=2):
def __init__(self, in_channels, bn_norm, num_splits, reduc_ratio=2):
super(Non_local, self).__init__()
self.in_channels = in_channels
@ -18,7 +19,7 @@ class Non_local(nn.Module):
self.W = nn.Sequential(
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.in_channels),
get_norm(bn_norm, self.in_channels, num_splits),
)
nn.init.constant_(self.W[1].weight, 0.0)
nn.init.constant_(self.W[1].bias, 0.0)
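
Usage-wise, the block now receives the norm type and split count explicitly instead of hard-coding nn.BatchNorm2d. A hedged one-liner (Non_local is imported from fastreid.layers, as the backbone diffs below do; the channel count is illustrative):

from fastreid.layers import Non_local

nl_plain = Non_local(in_channels=256, bn_norm="BN", num_splits=1)       # plain BN, as before
nl_ghost = Non_local(in_channels=256, bn_norm="GhostBN", num_splits=2)  # Ghost BN with 2 splits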

View File

@ -7,23 +7,32 @@
import torch
from torch import nn
__all__ = ["NoBiasBatchNorm1d", "IBN"]
__all__ = [
"BatchNorm",
"IBN",
"GhostBatchNorm",
"get_norm",
]
def NoBiasBatchNorm1d(in_features):
bn_layer = nn.BatchNorm1d(in_features)
bn_layer.bias.requires_grad_(False)
return bn_layer
class BatchNorm(nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
bias_init=0.0):
super().__init__(num_features, eps=eps, momentum=momentum)
if weight_init is not None: self.weight.data.fill_(weight_init)
if bias_init is not None: self.bias.data.fill_(bias_init)
self.weight.requires_grad = not weight_freeze
self.bias.requires_grad = not bias_freeze
class IBN(nn.Module):
def __init__(self, planes):
def __init__(self, planes, bn_norm, num_splits):
super(IBN, self).__init__()
half1 = int(planes / 2)
self.half = half1
half2 = planes - half1
self.IN = nn.InstanceNorm2d(half1, affine=True)
self.BN = nn.BatchNorm2d(half2)
self.BN = get_norm(bn_norm, half2, num_splits)
def forward(self, x):
split = torch.split(x, self.half, 1)
@ -31,3 +40,48 @@ class IBN(nn.Module):
out2 = self.BN(split[1].contiguous())
out = torch.cat((out1, out2), 1)
return out
class GhostBatchNorm(BatchNorm):
def __init__(self, num_features, num_splits=1, **kwargs):
super().__init__(num_features, **kwargs)
self.num_splits = num_splits
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
def forward(self, input):
N, C, H, W = input.shape
if self.training or not self.track_running_stats:
self.running_mean = self.running_mean.repeat(self.num_splits)
self.running_var = self.running_var.repeat(self.num_splits)
outputs = nn.functional.batch_norm(
input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
True, self.momentum, self.eps).view(N, C, H, W)
self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
return outputs
else:
return nn.functional.batch_norm(
input, self.running_mean, self.running_var,
self.weight, self.bias, False, self.momentum, self.eps)
def get_norm(norm, out_channels, num_splits=1, **kwargs):
"""
Args:
norm (str or callable):
Returns:
nn.Module or None: the normalization layer
"""
if isinstance(norm, str):
if len(norm) == 0:
return None
norm = {
"BN": BatchNorm(out_channels, **kwargs),
"GhostBN": GhostBatchNorm(out_channels, num_splits, **kwargs),
# "FrozenBN": FrozenBatchNorm2d,
# "GN": lambda channels: nn.GroupNorm(32, channels),
# "nnSyncBN": nn.SyncBatchNorm, # keep for debugging
}[norm]
return norm
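
To make the splitting trick above concrete, here is a small standalone sketch of GhostBatchNorm's training path using only PyTorch (the function and variable names are illustrative, not part of the codebase): the per-channel parameters and running statistics are tiled once per split, the split dimension is folded into the channel dimension so that a single batch_norm call normalizes each chunk of the batch with its own statistics, and the tiled running stats are averaged back to one set per channel afterwards.

import torch
import torch.nn.functional as F

def ghost_bn_train_step(x, weight, bias, running_mean, running_var,
                        num_splits, momentum=0.1, eps=1e-5):
    """One training-mode Ghost BN pass over a batch, split into num_splits chunks."""
    N, C, H, W = x.shape
    assert N % num_splits == 0, "batch size must be divisible by num_splits"
    # Tile parameters/statistics per split and fold the split dim into channels,
    # so each chunk of the batch gets its own mean/var in a single batch_norm call.
    rm = running_mean.repeat(num_splits)
    rv = running_var.repeat(num_splits)
    out = F.batch_norm(
        x.view(-1, C * num_splits, H, W), rm, rv,
        weight.repeat(num_splits), bias.repeat(num_splits),
        True, momentum, eps).view(N, C, H, W)
    # Collapse the tiled running statistics back to one set per channel.
    return out, rm.view(num_splits, C).mean(dim=0), rv.view(num_splits, C).mean(dim=0)

# A batch of 8 images normalized as 4 "ghost" batches of 2 samples each.
x = torch.randn(8, 3, 4, 4)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)
y, running_mean, running_var = ghost_bn_train_step(x, weight, bias,
                                                   running_mean, running_var, num_splits=4)
print(y.shape)  # torch.Size([8, 3, 4, 4])

Through the factory this is simply get_norm("GhostBN", channels, num_splits), while get_norm("BN", channels) keeps plain batch normalization.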

View File

@ -1,92 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from fastreid.modeling.model_utils import weights_init_kaiming
from ..layers import *
class OSM(nn.Module):
def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
super().__init__()
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.pool_layer = nn.Sequential(
pool_layer,
Flatten()
)
# bnneck
self.bnneck = NoBiasBatchNorm1d(in_feat)
self.bnneck.apply(weights_init_kaiming)
# classifier
self.alpha = 1.2 # margin of weighted contrastive loss, as mentioned in the paper
self.l = 0.5 # hyperparameter controlling weights of positive set and the negative set
# I haven't been able to figure out the use of \sigma CAA 0.18
self.osm_sigma = 0.8 # \sigma OSM (0.8) as mentioned in paper
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, features, targets=None):
global_feat = self.pool_layer(features)
bn_feat = self.bnneck(global_feat)
if not self.training:
return bn_feat
bn_feat = F.normalize(bn_feat)
n = bn_feat.size(0)
# Compute pairwise distance, replace by the official when merged
dist = torch.pow(bn_feat, 2).sum(dim=1, keepdim=True).expand(n, n)
dist = dist + dist.t()
dist.addmm_(1, -2, bn_feat, bn_feat.t())
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability & pairwise distance, dij
S = torch.exp(-1.0 * torch.pow(dist, 2) / (self.osm_sigma * self.osm_sigma))
S_ = torch.clamp(self.alpha - dist, min=1e-12) # max (0 , \alpha - dij) # 1e-12, 0 may result in nan error
p_mask = targets.expand(n, n).eq(targets.expand(n, n).t()) # same label == 1
n_mask = torch.bitwise_not(p_mask) # oposite label == 1
S = S * p_mask.float()
S = S + S_ * n_mask.float()
denominator = torch.exp(F.linear(bn_feat, F.normalize(self.weight)))
A = [] # attention corresponding to each feature fector
for i in range(n):
a_i = denominator[i][targets[i]] / torch.sum(denominator[i])
A.append(a_i)
# a_i's
atten_class = torch.stack(A)
# a_ij's
A = torch.min(atten_class.expand(n, n),
atten_class.view(-1, 1).expand(n, n)) # pairwise minimum of attention weights
W = S * A
W_P = W * p_mask.float()
W_N = W * n_mask.float()
W_P = W_P * (1 - torch.eye(n,
n).float().cuda()) # dist between (xi,xi) not necessarily 0, avoiding precision error
W_N = W_N * (1 - torch.eye(n, n).float().cuda())
L_P = 1.0 / 2 * torch.sum(W_P * torch.pow(dist, 2)) / torch.sum(W_P)
L_N = 1.0 / 2 * torch.sum(W_N * torch.pow(S_, 2)) / torch.sum(W_N)
L = (1 - self.l) * L_P + self.l * L_N
return L, global_feat

View File

@ -1,67 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from . import *
from ..modeling.losses.loss_utils import one_hot
from ..modeling.model_utils import weights_init_kaiming
class QAMHead(nn.Module):
def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
super().__init__()
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.pool_layer = nn.Sequential(
pool_layer,
Flatten()
)
# bnneck
self.bnneck = NoBiasBatchNorm1d(in_feat)
self.bnneck.apply(weights_init_kaiming)
# classifier
# self.adaptive_s = False
self._s = 6.0
self._m = 0.50
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, features, targets=None):
global_feat = self.pool_layer(features)
bn_feat = self.bnneck(global_feat)
if not self.training:
return bn_feat
# get cos(theta)
cosine = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
# add margin
theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7)) # for numerical stability
# --------------------------- convert label to one-hot ---------------------------
targets = one_hot(targets, self._num_classes)
phi = (2 * np.pi - (theta + self._m)) ** 2
others = (2 * np.pi - theta) ** 2
pred_class_logits = targets * phi + (1.0 - targets) * others
# logits re-scale
pred_class_logits *= self._s
return pred_class_logits, global_feat

View File

@ -9,6 +9,7 @@ import torch.nn.functional as F
from torch import nn
from torch.nn import Conv2d, ReLU
from torch.nn.modules.utils import _pair
from fastreid.layers import get_norm
class SplAtConv2d(nn.Module):
@ -18,7 +19,7 @@ class SplAtConv2d(nn.Module):
def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
dilation=(1, 1), groups=1, bias=True,
radix=2, reduction_factor=4,
rectify=False, rectify_avg=False, norm_layer=None,
rectify=False, rectify_avg=False, norm_layer=None, num_splits=1,
dropblock_prob=0.0, **kwargs):
super(SplAtConv2d, self).__init__()
padding = _pair(padding)
@ -38,11 +39,11 @@ class SplAtConv2d(nn.Module):
groups=groups * radix, bias=bias, **kwargs)
self.use_bn = norm_layer is not None
if self.use_bn:
self.bn0 = norm_layer(channels * radix)
self.bn0 = get_norm(norm_layer, channels * radix, num_splits)
self.relu = ReLU(inplace=True)
self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
if self.use_bn:
self.bn1 = norm_layer(inter_channels)
self.bn1 = get_norm(norm_layer, inter_channels, num_splits)
self.fc2 = Conv2d(inter_channels, channels * radix, 1, groups=self.cardinality)
self.rsoftmax = rSoftMax(radix, groups)
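
A hedged construction sketch for the updated signature, showing that norm_layer is now a string resolved through get_norm rather than a layer class (SplAtConv2d is exported from fastreid.layers, as the ResNeSt diff below imports it; the channel numbers are illustrative):

from fastreid.layers import SplAtConv2d

# radix-2 split-attention conv whose internal BN layers become Ghost BN with 2 splits
conv = SplAtConv2d(64, 128, kernel_size=3, padding=1,
                   radix=2, norm_layer="GhostBN", num_splits=2)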

View File

@ -9,8 +9,14 @@ import math
import torch
from torch import nn
from fastreid.layers import (
IBN,
Non_local,
SplAtConv2d,
get_norm,
)
from .build import BACKBONE_REGISTRY
from ...layers import SplAtConv2d, IBN, Non_local
_url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'
@ -39,18 +45,18 @@ class Bottleneck(nn.Module):
# pylint: disable=unused-argument
expansion = 4
def __init__(self, inplanes, planes, with_ibn=False, stride=1, downsample=None,
def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, stride=1, downsample=None,
radix=1, cardinality=1, bottleneck_width=64,
avd=False, avd_first=False, dilation=1, is_first=False,
rectified_conv=False, rectify_avg=False,
norm_layer=None, dropblock_prob=0.0, last_gamma=False):
dropblock_prob=0.0, last_gamma=False):
super(Bottleneck, self).__init__()
group_width = int(planes * (bottleneck_width / 64.)) * cardinality
self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
if with_ibn:
self.bn1 = IBN(group_width)
self.bn1 = IBN(group_width, bn_norm, num_splits)
else:
self.bn1 = norm_layer(group_width)
self.bn1 = get_norm(bn_norm, group_width, num_splits)
self.dropblock_prob = dropblock_prob
self.radix = radix
self.avd = avd and (stride > 1 or is_first)
@ -67,7 +73,7 @@ class Bottleneck(nn.Module):
dilation=dilation, groups=cardinality, bias=False,
radix=radix, rectify=rectified_conv,
rectify_avg=rectify_avg,
norm_layer=norm_layer,
norm_layer=bn_norm, num_splits=num_splits,
dropblock_prob=dropblock_prob)
elif rectified_conv:
from rfconv import RFConv2d
@ -76,17 +82,17 @@ class Bottleneck(nn.Module):
padding=dilation, dilation=dilation,
groups=cardinality, bias=False,
average_mode=rectify_avg)
self.bn2 = norm_layer(group_width)
self.bn2 = get_norm(bn_norm, group_width, num_splits)
else:
self.conv2 = nn.Conv2d(
group_width, group_width, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation,
groups=cardinality, bias=False)
self.bn2 = norm_layer(group_width)
self.bn2 = get_norm(bn_norm, group_width, num_splits)
self.conv3 = nn.Conv2d(
group_width, planes * 4, kernel_size=1, bias=False)
self.bn3 = norm_layer(planes * 4)
self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
if last_gamma:
from torch.nn.init import zeros_
@ -154,14 +160,14 @@ class ResNest(nn.Module):
"""
# pylint: disable=unused-variable
def __init__(self, last_stride, with_ibn, with_nl, block, layers, non_layers, radix=1, groups=1,
def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_nl, block, layers, non_layers, radix=1, groups=1,
bottleneck_width=64,
dilated=False, dilation=1,
deep_stem=False, stem_width=64, avg_down=False,
rectified_conv=False, rectify_avg=False,
avd=False, avd_first=False,
final_drop=0.0, dropblock_prob=0,
last_gamma=False, norm_layer=nn.BatchNorm2d):
last_gamma=False):
self.cardinality = groups
self.bottleneck_width = bottleneck_width
# ResNet-D params
@ -185,58 +191,52 @@ class ResNest(nn.Module):
if deep_stem:
self.conv1 = nn.Sequential(
conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
norm_layer(stem_width),
get_norm(bn_norm, stem_width, num_splits),
nn.ReLU(inplace=True),
conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
norm_layer(stem_width),
get_norm(bn_norm, stem_width, num_splits),
nn.ReLU(inplace=True),
conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
)
else:
self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
bias=False, **conv_kwargs)
self.bn1 = norm_layer(self.inplanes)
self.bn1 = get_norm(bn_norm, self.inplanes, num_splits)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], with_ibn=with_ibn, norm_layer=norm_layer, is_first=False)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, with_ibn=with_ibn, norm_layer=norm_layer)
self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn=with_ibn, is_first=False)
self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn=with_ibn)
if dilated or dilation == 4:
self.layer3 = self._make_layer(block, 256, layers[2], stride=1, with_ibn=with_ibn,
dilation=2, norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, with_ibn=with_ibn,
dilation=4, norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer3 = self._make_layer(block, 256, layers[2], 1, bn_norm, num_splits, with_ibn=with_ibn,
dilation=2, dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, num_splits, with_ibn=with_ibn,
dilation=4, dropblock_prob=dropblock_prob)
elif dilation == 2:
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn,
dilation=1, norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, with_ibn=with_ibn,
dilation=2, norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn,
dilation=1, dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, num_splits, with_ibn=with_ibn,
dilation=2, dropblock_prob=dropblock_prob)
else:
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn,
norm_layer=norm_layer,
self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn,
dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride, with_ibn=with_ibn,
norm_layer=norm_layer,
self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_ibn=with_ibn,
dropblock_prob=dropblock_prob)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, norm_layer):
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
if with_nl:
self._build_nonlocal(layers, non_layers)
self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
else:
self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False, dilation=1, norm_layer=None,
dropblock_prob=0.0, is_first=True):
def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", num_splits=1, with_ibn=False,
dilation=1, dropblock_prob=0.0, is_first=True):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
down_layers = []
@ -252,58 +252,58 @@ class ResNest(nn.Module):
else:
down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False))
down_layers.append(norm_layer(planes * block.expansion))
down_layers.append(get_norm(bn_norm, planes * block.expansion, num_splits))
downsample = nn.Sequential(*down_layers)
layers = []
if planes == 512:
with_ibn = False
if dilation == 1 or dilation == 2:
layers.append(block(self.inplanes, planes, with_ibn, stride, downsample=downsample,
layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, stride, downsample=downsample,
radix=self.radix, cardinality=self.cardinality,
bottleneck_width=self.bottleneck_width,
avd=self.avd, avd_first=self.avd_first,
dilation=1, is_first=is_first, rectified_conv=self.rectified_conv,
rectify_avg=self.rectify_avg,
norm_layer=norm_layer, dropblock_prob=dropblock_prob,
dropblock_prob=dropblock_prob,
last_gamma=self.last_gamma))
elif dilation == 4:
layers.append(block(self.inplanes, planes, with_ibn, stride, downsample=downsample,
layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, stride, downsample=downsample,
radix=self.radix, cardinality=self.cardinality,
bottleneck_width=self.bottleneck_width,
avd=self.avd, avd_first=self.avd_first,
dilation=2, is_first=is_first, rectified_conv=self.rectified_conv,
rectify_avg=self.rectify_avg,
norm_layer=norm_layer, dropblock_prob=dropblock_prob,
dropblock_prob=dropblock_prob,
last_gamma=self.last_gamma))
else:
raise RuntimeError("=> unknown dilation size: {}".format(dilation))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, with_ibn,
layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn,
radix=self.radix, cardinality=self.cardinality,
bottleneck_width=self.bottleneck_width,
avd=self.avd, avd_first=self.avd_first,
dilation=dilation, rectified_conv=self.rectified_conv,
rectify_avg=self.rectify_avg,
norm_layer=norm_layer, dropblock_prob=dropblock_prob,
dropblock_prob=dropblock_prob,
last_gamma=self.last_gamma))
return nn.Sequential(*layers)
def _build_nonlocal(self, layers, non_layers):
def _build_nonlocal(self, layers, non_layers, bn_norm, num_splits):
self.NL_1 = nn.ModuleList(
[Non_local(256) for _ in range(non_layers[0])])
[Non_local(256, bn_norm, num_splits) for _ in range(non_layers[0])])
self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
self.NL_2 = nn.ModuleList(
[Non_local(512) for _ in range(non_layers[1])])
[Non_local(512, bn_norm, num_splits) for _ in range(non_layers[1])])
self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
self.NL_3 = nn.ModuleList(
[Non_local(1024) for _ in range(non_layers[2])])
[Non_local(1024, bn_norm, num_splits) for _ in range(non_layers[2])])
self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
self.NL_4 = nn.ModuleList(
[Non_local(2048) for _ in range(non_layers[3])])
[Non_local(2048, bn_norm, num_splits) for _ in range(non_layers[3])])
self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
def forward(self, x):
@ -366,6 +366,8 @@ def build_resnest_backbone(cfg):
# fmt: off
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
bn_norm = cfg.MODEL.BACKBONE.NORM
num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_se = cfg.MODEL.BACKBONE.WITH_SE
with_nl = cfg.MODEL.BACKBONE.WITH_NL
@ -374,8 +376,8 @@ def build_resnest_backbone(cfg):
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 200: [3, 24, 36, 3], 269: [3, 30, 48, 8]}[depth]
nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]
stem_width = {50: 32, 101: 64, 200: 64, 269: 64}[depth]
model = ResNest(last_stride, with_ibn, with_nl, Bottleneck, num_blocks_per_stage, nl_layers_per_stage,
radix=2, groups=1, bottleneck_width=64,
model = ResNest(last_stride, bn_norm, num_splits, with_ibn, with_nl, Bottleneck, num_blocks_per_stage,
nl_layers_per_stage, radix=2, groups=1, bottleneck_width=64,
deep_stem=True, stem_width=stem_width, avg_down=True,
avd=True, avd_first=False)
if pretrain:

View File

@ -11,8 +11,14 @@ import torch
from torch import nn
from torch.utils import model_zoo
from fastreid.layers import (
IBN,
SELayer,
Non_local,
get_norm,
)
from .build import BACKBONE_REGISTRY
from ...layers import IBN, SELayer, Non_local
model_urls = {
18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
@ -28,18 +34,19 @@ __all__ = ['ResNet', 'Bottleneck']
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, with_ibn=False, with_se=False, stride=1, downsample=None, reduction=16):
def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn=False, with_se=False,
stride=1, downsample=None, reduction=16):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
if with_ibn:
self.bn1 = IBN(planes)
self.bn1 = IBN(planes, bn_norm, num_splits)
else:
self.bn1 = nn.BatchNorm2d(planes)
self.bn1 = get_norm(bn_norm, planes, num_splits)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.bn2 = get_norm(bn_norm, planes, num_splits)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
self.relu = nn.ReLU(inplace=True)
if with_se:
self.se = SELayer(planes * 4, reduction)
@ -73,59 +80,58 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module):
def __init__(self, last_stride, with_ibn, with_se, with_nl, block, layers, non_layers):
def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block, layers, non_layers):
scale = 64
self.inplanes = scale
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.bn1 = get_norm(bn_norm, 64, num_splits)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, scale, layers[0], with_ibn=with_ibn, with_se=with_se)
self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2, with_ibn=with_ibn, with_se=with_se)
self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2, with_ibn=with_ibn, with_se=with_se)
self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=last_stride, with_se=with_se)
self.layer1 = self._make_layer(block, scale, layers[0], 1, bn_norm, num_splits, with_ibn, with_se)
self.layer2 = self._make_layer(block, scale * 2, layers[1], 2, bn_norm, num_splits, with_ibn, with_se)
self.layer3 = self._make_layer(block, scale * 4, layers[2], 2, bn_norm, num_splits, with_ibn, with_se)
self.layer4 = self._make_layer(block, scale * 8, layers[3], last_stride, bn_norm, num_splits, with_se=with_se)
self.random_init()
if with_nl:
self._build_nonlocal(layers, non_layers)
self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
else:
self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False, with_se=False):
def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", num_splits=1, with_ibn=False, with_se=False):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
get_norm(bn_norm, planes * block.expansion, num_splits),
)
layers = []
if planes == 512:
with_ibn = False
layers.append(block(self.inplanes, planes, with_ibn, with_se, stride, downsample))
layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, with_se, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, with_ibn, with_se))
layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, with_se))
return nn.Sequential(*layers)
def _build_nonlocal(self, layers, non_layers):
def _build_nonlocal(self, layers, non_layers, bn_norm, num_splits):
self.NL_1 = nn.ModuleList(
[Non_local(256) for _ in range(non_layers[0])])
[Non_local(256, bn_norm, num_splits) for _ in range(non_layers[0])])
self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
self.NL_2 = nn.ModuleList(
[Non_local(512) for _ in range(non_layers[1])])
[Non_local(512, bn_norm, num_splits) for _ in range(non_layers[1])])
self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
self.NL_3 = nn.ModuleList(
[Non_local(1024) for _ in range(non_layers[2])])
[Non_local(1024, bn_norm, num_splits) for _ in range(non_layers[2])])
self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
self.NL_4 = nn.ModuleList(
[Non_local(2048) for _ in range(non_layers[3])])
[Non_local(2048, bn_norm, num_splits) for _ in range(non_layers[3])])
self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
def forward(self, x):
@ -198,14 +204,17 @@ def build_resnet_backbone(cfg):
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
bn_norm = cfg.MODEL.BACKBONE.NORM
num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_se = cfg.MODEL.BACKBONE.WITH_SE
with_nl = cfg.MODEL.BACKBONE.WITH_NL
depth = cfg.MODEL.BACKBONE.DEPTH
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]
model = ResNet(last_stride, with_ibn, with_se, with_nl, Bottleneck, num_blocks_per_stage, nl_layers_per_stage)
nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 9, 0]}[depth]
model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, Bottleneck,
num_blocks_per_stage, nl_layers_per_stage)
if pretrain:
if not with_ibn:
# original resnet

View File

@ -4,21 +4,19 @@
@contact: sherlockliao01@gmail.com
"""
from fastreid.layers import *
from fastreid.utils.weight_init import weights_init_kaiming
from .build import REID_HEADS_REGISTRY
from ..model_utils import weights_init_kaiming
from ...layers import *
@REID_HEADS_REGISTRY.register()
class BNneckHead(nn.Module):
def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
super().__init__()
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
self.pool_layer = pool_layer
self.pool_layer = nn.Sequential(
pool_layer,
Flatten()
)
self.bnneck = NoBiasBatchNorm1d(in_feat)
self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, in_feat, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
self.bnneck.apply(weights_init_kaiming)
# identity classification layer
@ -37,12 +35,20 @@ class BNneckHead(nn.Module):
"""
global_feat = self.pool_layer(features)
bn_feat = self.bnneck(global_feat)
# evaluation
bn_feat = Flatten()(bn_feat)
# Evaluation
if not self.training:
return bn_feat
# training
# Training
try:
pred_class_logits = self.classifier(bn_feat)
except TypeError:
pred_class_logits = self.classifier(bn_feat, targets)
return pred_class_logits, bn_feat, targets
if self.neck_feat == "before":
feat = Flatten()(global_feat)
elif self.neck_feat == "after":
feat = bn_feat
else:
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
return pred_class_logits, feat, targets
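
For orientation, a hedged sketch of the head's train/eval contract after this change; the constructor arguments and return values follow the diff above, while the import paths and feature-map shape are assumptions.

import torch
from fastreid.config import get_cfg                 # assumed config entry point
from fastreid.modeling.heads import BNneckHead      # import path assumed

cfg = get_cfg()
cfg.MODEL.HEADS.NECK_FEAT = "before"                 # triplet feature taken before the bnneck
head = BNneckHead(cfg, in_feat=2048, num_classes=751)

x = torch.randn(4, 2048, 16, 8)                      # stand-in for backbone feature maps
targets = torch.randint(0, 751, (4,))

head.train()
logits, feat, targets = head(x, targets)             # feat: (4, 2048) pooled feature before bnneck

head.eval()
bn_feat = head(x)                                    # (4, 2048) flattened post-bnneck feature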

View File

@ -4,19 +4,15 @@
@contact: sherlockliao01@gmail.com
"""
from fastreid.layers import *
from .build import REID_HEADS_REGISTRY
from ...layers import *
@REID_HEADS_REGISTRY.register()
class LinearHead(nn.Module):
def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
super().__init__()
self.pool_layer = nn.Sequential(
pool_layer,
Flatten()
)
self.pool_layer = pool_layer
# identity classification layer
if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
@ -33,6 +29,7 @@ class LinearHead(nn.Module):
See :class:`ReIDHeads.forward`.
"""
global_feat = self.pool_layer(features)
global_feat = Flatten()(global_feat)
if not self.training:
return global_feat
# training

View File

@ -4,29 +4,27 @@
@contact: sherlockliao01@gmail.com
"""
from fastreid.layers import *
from fastreid.utils.weight_init import weights_init_kaiming
from .build import REID_HEADS_REGISTRY
from ..model_utils import weights_init_kaiming
from ...layers import *
@REID_HEADS_REGISTRY.register()
class ReductionHead(nn.Module):
def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
super().__init__()
reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM
self.pool_layer = nn.Sequential(
pool_layer,
Flatten()
)
self.pool_layer = pool_layer
self.bottleneck = nn.Sequential(
nn.Linear(in_feat, reduction_dim, bias=False),
NoBiasBatchNorm1d(reduction_dim),
nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False),
BatchNorm(reduction_dim, bias_freeze=True),
nn.LeakyReLU(0.1),
nn.Dropout(0.5),
nn.Dropout2d(0.5),
)
self.bnneck = NoBiasBatchNorm1d(reduction_dim)
self.bnneck = BatchNorm(reduction_dim, bias_freeze=True)
self.bottleneck.apply(weights_init_kaiming)
self.bnneck.apply(weights_init_kaiming)
@ -48,11 +46,20 @@ class ReductionHead(nn.Module):
global_feat = self.pool_layer(features)
global_feat = self.bottleneck(global_feat)
bn_feat = self.bnneck(global_feat)
bn_feat = Flatten()(bn_feat)
# Evaluation
if not self.training:
return bn_feat
# training
# Training
try:
pred_class_logits = self.classifier(bn_feat)
except TypeError:
pred_class_logits = self.classifier(bn_feat, targets)
return pred_class_logits, bn_feat, targets
if self.neck_feat == "before":
feat = Flatten()(global_feat)
elif self.neck_feat == "after":
feat = bn_feat
else:
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
return pred_class_logits, feat, targets

View File

@ -6,8 +6,7 @@
import torch
import torch.nn.functional as F
from ...utils.events import get_event_storage
from .loss_utils import one_hot
from fastreid.utils.events import get_event_storage
class CrossEntropyLoss(object):

View File

@ -7,7 +7,7 @@
import torch
import torch.nn.functional as F
from .loss_utils import one_hot
from fastreid.utils.one_hot import one_hot
# based on:

View File

@ -8,8 +8,11 @@ import torch
from torch import nn
import torch.nn.functional as F
__all__ = [
"TripletLoss",
"CircleLoss",
]
__all__ = ["TripletLoss", "CircleLoss"]
def normalize(x, axis=-1):
"""Normalizing to unit length along the specified dimension.

View File

@ -6,11 +6,11 @@
from torch import nn
from fastreid.layers import GeneralizedMeanPoolingP
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.heads import build_reid_heads
from fastreid.modeling.losses import reid_losses
from .build import META_ARCH_REGISTRY
from ..backbones import build_backbone
from ..heads import build_reid_heads
from ..losses import reid_losses
from ...layers import GeneralizedMeanPoolingP
@META_ARCH_REGISTRY.register()

View File

@ -104,11 +104,8 @@ class Checkpointer(object):
assert os.path.isfile(path), "Checkpoint {} not found!".format(path)
checkpoint = self._load_file(path)
if self.dataset is None:
self.logger.info(
"No need to load dataset pid dictionary"
)
else:
if self.dataset is not None:
self.logger.info("Loading dataset pid dictionary from {}".format(path))
self._load_dataset_pid_dict(checkpoint)
self._load_model(checkpoint)
for key, obj in self.checkpointables.items():

View File

@ -8,14 +8,15 @@ from typing import Optional
import torch
# based on:
# https://github.com/kornia/kornia/blob/master/kornia/utils/one_hot.py
def one_hot(labels: torch.Tensor,
num_classes: int,
dtype: Optional[torch.dtype] = None,) -> torch.Tensor:
# eps: Optional[float] = 1e-6) -> torch.Tensor:
dtype: Optional[torch.dtype] = None, ) -> torch.Tensor:
# eps: Optional[float] = 1e-6) -> torch.Tensor:
r"""Converts an integer label x-D tensor to a one-hot (x+1)-D tensor.
Args:
labels (torch.Tensor) : tensor with labels of shape :math:`(N, *)`,
@ -45,7 +46,7 @@ def one_hot(labels: torch.Tensor,
.format(type(labels)))
if not labels.dtype == torch.int64:
raise ValueError(
"labels must be of the same dtype torch.int64. Got: {}" .format(
"labels must be of the same dtype torch.int64. Got: {}".format(
labels.dtype))
if num_classes < 1:
raise ValueError("The number of classes must be bigger than one."
@ -54,4 +55,4 @@ def one_hot(labels: torch.Tensor,
shape = labels.shape
one_hot = torch.zeros(shape[0], num_classes, *shape[1:],
device=device, dtype=dtype)
return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)
return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)
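
A tiny usage sketch of the relocated helper (the import path matches the new imports above; the labels are illustrative):

import torch
from fastreid.utils.one_hot import one_hot

labels = torch.tensor([0, 2, 1])        # int64 class indices, shape (N,)
print(one_hot(labels, num_classes=3))
# tensor([[1., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.]])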

View File

@ -8,7 +8,6 @@ import itertools
import torch
BN_MODULE_TYPES = (
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
@ -42,7 +41,6 @@ def update_bn_stats(model, data_loader, num_iters: int = 200):
num_iters (int): number of iterations to compute the stats.
"""
bn_layers = get_bn_modules(model)
if len(bn_layers) == 0:
return
@ -59,9 +57,11 @@ def update_bn_stats(model, data_loader, num_iters: int = 200):
running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
# Change targets to zero to avoid error in
# circle(arcface) loss which will use targets in forward
inputs['targets'].zero_()
with torch.no_grad(): # No need to backward
model(inputs)
for i, bn in enumerate(bn_layers):
# Accumulates the bn stats.
running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
@ -91,8 +91,6 @@ def get_bn_modules(model):
"""
# Finds all the bn layers.
bn_layers = [
m
for m in model.modules()
if m.training and isinstance(m, BN_MODULE_TYPES)
m for m in model.modules() if m.training and isinstance(m, BN_MODULE_TYPES)
]
return bn_layers

View File

@ -1,11 +1,15 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
from torch import nn
__all__ = ['weights_init_classifier', 'weights_init_kaiming', ]
__all__ = [
'weights_init_classifier',
'weights_init_kaiming',
]
def weights_init_kaiming(m):