refactor(fastreid)

refactor architecture
pull/43/head
liaoxingyu 2020-04-20 10:59:29 +08:00
parent 9684500a57
commit 95a3c62ad2
64 changed files with 915 additions and 620 deletions

View File

@ -35,6 +35,8 @@ _C.MODEL.BACKBONE.LAST_STRIDE = 1
_C.MODEL.BACKBONE.WITH_IBN = False
# If use SE block in backbone
_C.MODEL.BACKBONE.WITH_SE = False
# If use Non-local block in backbone
_C.MODEL.BACKBONE.WITH_NL = False
# If use ImageNet pretrain model
_C.MODEL.BACKBONE.PRETRAIN = True
# Pretrain model path

View File

@ -0,0 +1,166 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from collections import defaultdict
import argparse
import json
import os
import sys
import time
import cv2
import numpy as np
import torch
from torch.backends import cudnn
from fastreid.modeling import build_model
from fastreid.utils.checkpoint import Checkpointer
from fastreid.config import get_cfg
cudnn.benchmark = True
class Reid(object):
def __init__(self, config_file):
cfg = get_cfg()
cfg.merge_from_file(config_file)
cfg.defrost()
cfg.MODEL.WEIGHTS = 'projects/bjzProject/logs/bjz/arcface_adam/model_final.pth'
model = build_model(cfg)
Checkpointer(model).resume_or_load(cfg.MODEL.WEIGHTS)
model.cuda()
model.eval()
self.model = model
# self.model = torch.jit.load("reid_model.pt")
# self.model.eval()
# self.model.cuda()
example = torch.rand(1, 3, 256, 128)
example = example.cuda()
traced_script_module = torch.jit.trace_module(model, {'inference': example})
traced_script_module.save("reid_feat_extractor.pt")
@classmethod
def preprocess(cls, img_path):
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (128, 256))
img = img / 255.0
img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
img = img.transpose((2, 0, 1)).astype(np.float32)
img = img[np.newaxis, :, :, :]
data = torch.from_numpy(img).cuda().float()
return data
@torch.no_grad()
def demo(self, img_path):
data = self.preprocess(img_path)
output = self.model.inference(data)
feat = output.cpu().data.numpy()
return feat
# @torch.no_grad()
# def extract_feat(self, dataloader):
# prefetcher = test_data_prefetcher(dataloader)
# feats = []
# labels = []
# batch = prefetcher.next()
# num_count = 0
# while batch[0] is not None:
# img, pid, camid = batch
# feat = self.model(img)
# feats.append(feat.cpu())
# labels.extend(np.asarray(pid))
#
# # if num_count > 2:
# # break
# batch = prefetcher.next()
# # num_count += 1
#
# feats = torch.cat(feats, dim=0)
# id_feats = defaultdict(list)
# for f, i in zip(feats, labels):
# id_feats[i].append(f)
# all_feats = []
# label_names = []
# for i in id_feats:
# all_feats.append(torch.stack(id_feats[i], dim=0).mean(dim=0))
# label_names.append(i)
#
# label_names = np.asarray(label_names)
# all_feats = torch.stack(all_feats, dim=0) # (n, 2048)
# all_feats = F.normalize(all_feats, p=2, dim=1)
# np.save('feats.npy', all_feats.cpu())
# np.save('labels.npy', label_names)
# cos = torch.mm(all_feats, all_feats.t()).numpy() # (n, n)
# cos -= np.eye(all_feats.shape[0])
# f = open('check_cross_folder_similarity.txt', 'w')
# for i in range(len(label_names)):
# sim_indx = np.argwhere(cos[i] > 0.5)[:, 0]
# sim_name = label_names[sim_indx]
# write_str = label_names[i] + ' '
# # f.write(label_names[i]+'\t')
# for n in sim_name:
# write_str += (n + ' ')
# # f.write(n+'\t')
# f.write(write_str+'\n')
#
#
# def prepare_gt(self, json_file):
# feat = []
# label = []
# with open(json_file, 'r') as f:
# total = json.load(f)
# for index in total:
# label.append(index)
# feat.append(np.array(total[index]))
# time_label = [int(i[0:10]) for i in label]
#
# return np.array(feat), np.array(label), np.array(time_label)
def compute_topk(self, k, feat, feats, label):
# num_gallery = feats.shape[0]
# new_feat = np.tile(feat,[num_gallery,1])
norm_feat = np.sqrt(np.sum(np.square(feat), axis=-1))
norm_feats = np.sqrt(np.sum(np.square(feats), axis=-1))
matrix = np.sum(np.multiply(feat, feats), axis=-1)
dist = matrix / np.multiply(norm_feat, norm_feats)
# print('feat:',feat.shape)
# print('feats:',feats.shape)
# print('label:',label.shape)
# print('dist:',dist.shape)
index = np.argsort(-dist)
# print('index:',index.shape)
result = []
for i in range(min(feats.shape[0], k)):
print(dist[index[i]])
result.append(label[index[i]])
return result
if __name__ == '__main__':
reid_sys = Reid(config_file='../../projects/bjzProject/configs/bjz.yml')
img_path = '/export/home/lxy/beijingStationReID/reid_model/demo_imgs/003740_c5s2_1561733125170.000000.jpg'
feat = reid_sys.demo(img_path)
feat_extractor = torch.jit.load('reid_feat_extractor.pt')
data = reid_sys.preprocess(img_path)
feat2 = feat_extractor.inference(data)
from ipdb import set_trace; set_trace()
# imgs = os.listdir(img_path)
# feats = {}
# for i in range(len(imgs)):
# feat = reid.demo(os.path.join(img_path, imgs[i]))
# feats[imgs[i]] = feat
# feat = reid.demo(os.path.join(img_path, 'crop_img0.jpg'))
# out1 = feats['dog.jpg']
# out2 = feats['kobe2.jpg']
# innerProduct = np.dot(out1, out2.T)
# cosineSimilarity = innerProduct / (np.linalg.norm(out1, ord=2) * np.linalg.norm(out2, ord=2))
# print(f'cosine similarity is {cosineSimilarity[0][0]:.4f}')

View File

@ -7,8 +7,10 @@ from torch import nn
from .batch_drop import BatchDrop
from .attention import *
from .batch_norm import NoBiasBatchNorm1d
from .norm import *
from .context_block import ContextBlock
from .non_local import Non_local
from .se_layer import SELayer
from .frn import FRN, TLU
from .mish import Mish
from .gem_pool import GeneralizedMeanPoolingP

View File

@ -10,8 +10,8 @@ import torch
import torch.nn.functional as F
from torch.nn import Parameter
from ..layers import *
from ..model_utils import weights_init_kaiming
from . import *
from ..modeling.model_utils import weights_init_kaiming
class AdaCos(nn.Module):

View File

@ -11,7 +11,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from ..losses.loss_utils import one_hot
from ..modeling.losses.loss_utils import one_hot
class Arcface(nn.Module):

View File

@ -11,7 +11,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from ..losses.loss_utils import one_hot
from ..modeling.losses.loss_utils import one_hot
class Circle(nn.Module):

View File

@ -0,0 +1,33 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
__all__ = ["NoBiasBatchNorm1d", "IBN"]
def NoBiasBatchNorm1d(in_features):
bn_layer = nn.BatchNorm1d(in_features)
bn_layer.bias.requires_grad_(False)
return bn_layer
class IBN(nn.Module):
def __init__(self, planes):
super(IBN, self).__init__()
half1 = int(planes / 2)
self.half = half1
half2 = planes - half1
self.IN = nn.InstanceNorm2d(half1, affine=True)
self.BN = nn.BatchNorm2d(half2)
def forward(self, x):
split = torch.split(x, self.half, 1)
out1 = self.IN(split[0].contiguous())
out2 = self.BN(split[1].contiguous())
out = torch.cat((out1, out2), 1)
return out

View File

@ -11,7 +11,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from ..model_utils import weights_init_kaiming
from fastreid.modeling.model_utils import weights_init_kaiming
from ..layers import *

View File

@ -11,9 +11,9 @@ import torch
import torch.nn.functional as F
from torch.nn import Parameter
from ..layers import *
from ..losses.loss_utils import one_hot
from ..model_utils import weights_init_kaiming
from . import *
from ..modeling.losses.loss_utils import one_hot
from ..modeling.model_utils import weights_init_kaiming
class QAMHead(nn.Module):

View File

@ -0,0 +1,25 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn
class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, int(channel / reduction), bias=False),
nn.ReLU(inplace=True),
nn.Linear(int(channel / reduction), channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)

View File

@ -0,0 +1,338 @@
# encoding: utf-8
# based on:
# https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/resnest.py
"""ResNeSt models"""
import torch
from torch import nn
import math
import logging
from .resnet import ResNet, Bottleneck
from .build import BACKBONE_REGISTRY
_url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'
_model_sha256 = {name: checksum for checksum, name in [
('528c19ca', '50'),
('22405ba7', '101'),
('75117900', '200'),
('0cc87c48', '269'),
]}
def short_hash(name):
if name not in _model_sha256:
raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
return _model_sha256[name][:8]
resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for
name in _model_sha256.keys()
}
class Bottleneck(nn.Module):
"""ResNet Bottleneck
"""
# pylint: disable=unused-argument
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
radix=1, cardinality=1, bottleneck_width=64,
avd=False, avd_first=False, dilation=1, is_first=False,
rectified_conv=False, rectify_avg=False,
norm_layer=None, dropblock_prob=0.0, last_gamma=False):
super(Bottleneck, self).__init__()
group_width = int(planes * (bottleneck_width / 64.)) * cardinality
self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
self.bn1 = norm_layer(group_width)
self.dropblock_prob = dropblock_prob
self.radix = radix
self.avd = avd and (stride > 1 or is_first)
self.avd_first = avd_first
if self.avd:
self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
stride = 1
if radix > 1:
self.conv2 = SplAtConv2d(
group_width, group_width, kernel_size=3,
stride=stride, padding=dilation,
dilation=dilation, groups=cardinality, bias=False,
radix=radix, rectify=rectified_conv,
rectify_avg=rectify_avg,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
elif rectified_conv:
from rfconv import RFConv2d
self.conv2 = RFConv2d(
group_width, group_width, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation,
groups=cardinality, bias=False,
average_mode=rectify_avg)
self.bn2 = norm_layer(group_width)
else:
self.conv2 = nn.Conv2d(
group_width, group_width, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation,
groups=cardinality, bias=False)
self.bn2 = norm_layer(group_width)
self.conv3 = nn.Conv2d(
group_width, planes * 4, kernel_size=1, bias=False)
self.bn3 = norm_layer(planes * 4)
if last_gamma:
from torch.nn.init import zeros_
zeros_(self.bn3.weight)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.dilation = dilation
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
if self.dropblock_prob > 0.0:
out = self.dropblock1(out)
out = self.relu(out)
if self.avd and self.avd_first:
out = self.avd_layer(out)
out = self.conv2(out)
if self.radix == 1:
out = self.bn2(out)
if self.dropblock_prob > 0.0:
out = self.dropblock2(out)
out = self.relu(out)
if self.avd and not self.avd_first:
out = self.avd_layer(out)
out = self.conv3(out)
out = self.bn3(out)
if self.dropblock_prob > 0.0:
out = self.dropblock3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNest(nn.Module):
"""ResNet Variants ResNest
Parameters
----------
block : Block
Class for the residual block. Options are BasicBlockV1, BottleneckV1.
layers : list of int
Numbers of layers in each block
classes : int, default 1000
Number of classification classes.
dilated : bool, default False
Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
typically used in Semantic Segmentation.
norm_layer : object
Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
for Synchronized Cross-GPU BachNormalization).
Reference:
- He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
- Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
"""
# pylint: disable=unused-variable
def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64,
num_classes=1000, dilated=False, dilation=1,
deep_stem=False, stem_width=64, avg_down=False,
rectified_conv=False, rectify_avg=False,
avd=False, avd_first=False,
final_drop=0.0, dropblock_prob=0,
last_gamma=False, norm_layer=nn.BatchNorm2d):
self.cardinality = groups
self.bottleneck_width = bottleneck_width
# ResNet-D params
self.inplanes = stem_width * 2 if deep_stem else 64
self.avg_down = avg_down
self.last_gamma = last_gamma
# ResNeSt params
self.radix = radix
self.avd = avd
self.avd_first = avd_first
super(ResNet, self).__init__()
self.rectified_conv = rectified_conv
self.rectify_avg = rectify_avg
if rectified_conv:
from rfconv import RFConv2d
conv_layer = RFConv2d
else:
conv_layer = nn.Conv2d
conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
if deep_stem:
self.conv1 = nn.Sequential(
conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
norm_layer(stem_width),
nn.ReLU(inplace=True),
conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
norm_layer(stem_width),
nn.ReLU(inplace=True),
conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
)
else:
self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
bias=False, **conv_kwargs)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
if dilated or dilation == 4:
self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
dilation=2, norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
dilation=4, norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
elif dilation == 2:
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilation=1, norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
dilation=2, norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
else:
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
norm_layer=norm_layer,
dropblock_prob=dropblock_prob)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, norm_layer):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
dropblock_prob=0.0, is_first=True):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
down_layers = []
if self.avg_down:
if dilation == 1:
down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
ceil_mode=True, count_include_pad=False))
else:
down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1,
ceil_mode=True, count_include_pad=False))
down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=1, bias=False))
else:
down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False))
down_layers.append(norm_layer(planes * block.expansion))
downsample = nn.Sequential(*down_layers)
layers = []
if dilation == 1 or dilation == 2:
layers.append(block(self.inplanes, planes, stride, downsample=downsample,
radix=self.radix, cardinality=self.cardinality,
bottleneck_width=self.bottleneck_width,
avd=self.avd, avd_first=self.avd_first,
dilation=1, is_first=is_first, rectified_conv=self.rectified_conv,
rectify_avg=self.rectify_avg,
norm_layer=norm_layer, dropblock_prob=dropblock_prob,
last_gamma=self.last_gamma))
elif dilation == 4:
layers.append(block(self.inplanes, planes, stride, downsample=downsample,
radix=self.radix, cardinality=self.cardinality,
bottleneck_width=self.bottleneck_width,
avd=self.avd, avd_first=self.avd_first,
dilation=2, is_first=is_first, rectified_conv=self.rectified_conv,
rectify_avg=self.rectify_avg,
norm_layer=norm_layer, dropblock_prob=dropblock_prob,
last_gamma=self.last_gamma))
else:
raise RuntimeError("=> unknown dilation size: {}".format(dilation))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes,
radix=self.radix, cardinality=self.cardinality,
bottleneck_width=self.bottleneck_width,
avd=self.avd, avd_first=self.avd_first,
dilation=dilation, rectified_conv=self.rectified_conv,
rectify_avg=self.rectify_avg,
norm_layer=norm_layer, dropblock_prob=dropblock_prob,
last_gamma=self.last_gamma))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
@BACKBONE_REGISTRY.register()
def build_resnest_backbone(cfg):
"""
Create a ResNest instance from config.
Returns:
ResNet: a :class:`ResNet` instance.
"""
# fmt: off
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_se = cfg.MODEL.BACKBONE.WITH_SE
depth = cfg.MODEL.BACKBONE.DEPTH
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 200: [3, 24, 36, 3], 269: [3, 30, 48, 8]}[depth]
# model = ResNet(last_stride, with_ibn, with_se, Bottleneck, num_blocks_per_stage)
model = ResNest(Bottleneck, [3, 4, 6, 3],
radix=2, groups=1, bottleneck_width=64,
deep_stem=True, stem_width=32, avg_down=True,
avd=True, avd_first=False)
if pretrain:
if not with_ibn:
# original resnet
state_dict = torch.hub.load_state_dict_from_url(
resnest_model_urls[depth], progress=True, check_hash=True)
else:
raise KeyError('Not implementation ibn in resnest')
# # ibn resnet
# state_dict = torch.load(pretrain_path)['state_dict']
# # remove module in name
# new_state_dict = {}
# for k in state_dict:
# new_k = '.'.join(k.split('.')[1:])
# if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
# new_state_dict[new_k] = state_dict[k]
# state_dict = new_state_dict
res = model.load_state_dict(state_dict, strict=False)
logger = logging.getLogger(__name__)
logger.info('missing keys is {}'.format(res.missing_keys))
logger.info('unexpected keys is {}'.format(res.unexpected_keys))
return model

View File

@ -12,6 +12,7 @@ from torch import nn
from torch.utils import model_zoo
from .build import BACKBONE_REGISTRY
from ...layers import IBN, SELayer, Non_local
model_urls = {
18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
@ -19,50 +20,11 @@ model_urls = {
50: 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
101: 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
152: 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
# 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
# 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
# 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
# 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
__all__ = ['ResNet', 'Bottleneck']
class IBN(nn.Module):
def __init__(self, planes):
super(IBN, self).__init__()
half1 = int(planes / 2)
self.half = half1
half2 = planes - half1
self.IN = nn.InstanceNorm2d(half1, affine=True)
self.BN = nn.BatchNorm2d(half2)
def forward(self, x):
split = torch.split(x, self.half, 1)
out1 = self.IN(split[0].contiguous())
out2 = self.BN(split[1].contiguous())
out = torch.cat((out1, out2), 1)
return out
class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, int(channel / reduction), bias=False),
nn.ReLU(inplace=True),
nn.Linear(int(channel / reduction), channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
class Bottleneck(nn.Module):
expansion = 4
@ -111,7 +73,7 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module):
def __init__(self, last_stride, with_ibn, with_se, block, layers):
def __init__(self, last_stride, with_ibn, with_se, with_nl, block, layers, non_layers):
scale = 64
self.inplanes = scale
super().__init__()
@ -127,6 +89,12 @@ class ResNet(nn.Module):
self.random_init()
if with_nl:
self._build_nonlocal(layers, non_layers)
else:
self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False, with_se=False):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
@ -146,16 +114,65 @@ class ResNet(nn.Module):
return nn.Sequential(*layers)
def _build_nonlocal(self, layers, non_layers):
self.NL_1 = nn.ModuleList(
[Non_local(256) for _ in range(non_layers[0])])
self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
self.NL_2 = nn.ModuleList(
[Non_local(512) for _ in range(non_layers[1])])
self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
self.NL_3 = nn.ModuleList(
[Non_local(1024) for _ in range(non_layers[2])])
self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
self.NL_4 = nn.ModuleList(
[Non_local(2048) for _ in range(non_layers[3])])
self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
NL1_counter = 0
if len(self.NL_1_idx) == 0:
self.NL_1_idx = [-1]
for i in range(len(self.layer1)):
x = self.layer1[i](x)
if i == self.NL_1_idx[NL1_counter]:
_, C, H, W = x.shape
x = self.NL_1[NL1_counter](x)
NL1_counter += 1
# Layer 2
NL2_counter = 0
if len(self.NL_2_idx) == 0:
self.NL_2_idx = [-1]
for i in range(len(self.layer2)):
x = self.layer2[i](x)
if i == self.NL_2_idx[NL2_counter]:
_, C, H, W = x.shape
x = self.NL_2[NL2_counter](x)
NL2_counter += 1
# Layer 3
NL3_counter = 0
if len(self.NL_3_idx) == 0:
self.NL_3_idx = [-1]
for i in range(len(self.layer3)):
x = self.layer3[i](x)
if i == self.NL_3_idx[NL3_counter]:
_, C, H, W = x.shape
x = self.NL_3[NL3_counter](x)
NL3_counter += 1
# Layer 4
NL4_counter = 0
if len(self.NL_4_idx) == 0:
self.NL_4_idx = [-1]
for i in range(len(self.layer4)):
x = self.layer4[i](x)
if i == self.NL_4_idx[NL4_counter]:
_, C, H, W = x.shape
x = self.NL_4[NL4_counter](x)
NL4_counter += 1
return x
@ -183,10 +200,12 @@ def build_resnet_backbone(cfg):
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_se = cfg.MODEL.BACKBONE.WITH_SE
with_nl = cfg.MODEL.BACKBONE.WITH_NL
depth = cfg.MODEL.BACKBONE.DEPTH
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]
model = ResNet(last_stride, with_ibn, with_se, Bottleneck, num_blocks_per_stage)
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]
model = ResNet(last_stride, with_ibn, with_se, with_nl, Bottleneck, num_blocks_per_stage, nl_layers_per_stage)
if pretrain:
if not with_ibn:
# original resnet
@ -206,5 +225,3 @@ def build_resnet_backbone(cfg):
logger.info('missing keys is {}'.format(res.missing_keys))
logger.info('unexpected keys is {}'.format(res.unexpected_keys))
return model

View File

@ -5,8 +5,8 @@
"""
from .build import REID_HEADS_REGISTRY
from ..layers import *
from ..model_utils import weights_init_kaiming
from ...layers import *
@REID_HEADS_REGISTRY.register()

View File

@ -5,7 +5,7 @@
"""
from .build import REID_HEADS_REGISTRY
from ..layers import *
from ...layers import *
@REID_HEADS_REGISTRY.register()

View File

@ -5,8 +5,8 @@
"""
from .build import REID_HEADS_REGISTRY
from ..layers import *
from ..model_utils import weights_init_kaiming
from ...layers import *
@REID_HEADS_REGISTRY.register()

View File

@ -1,13 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn
def NoBiasBatchNorm1d(in_features):
bn_layer = nn.BatchNorm1d(in_features)
bn_layer.bias.requires_grad_(False)
return bn_layer

View File

@ -1,26 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn
class SEModule(nn.Module):
def __init__(self, channels, reduciton):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(channels, channels//reduciton, kernel_size=1, padding=0, bias=False)
self.relu = nn.ReLU(True)
self.fc2 = nn.Conv2d(channels//reduciton, channels, kernel_size=1, padding=0, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x

View File

@ -14,7 +14,7 @@ from .build import META_ARCH_REGISTRY
from ..backbones import build_backbone
from ..heads import build_reid_heads
from ..model_utils import weights_init_kaiming
from fastreid.modeling.layers import CAM_Module, PAM_Module, DANetHead, Flatten
from fastreid.layers import CAM_Module, PAM_Module, DANetHead, Flatten
@META_ARCH_REGISTRY.register()

View File

@ -10,7 +10,7 @@ from torch import nn
from .build import META_ARCH_REGISTRY
from ..backbones import build_backbone
from ..heads import build_reid_heads
from ..layers import GeneralizedMeanPoolingP
from ...layers import GeneralizedMeanPoolingP
from ..losses import reid_losses

View File

@ -13,7 +13,7 @@ from ..backbones import build_backbone
from ..backbones.resnet import Bottleneck
from ..heads import build_reid_heads
from ..model_utils import weights_init_kaiming
from fastreid.modeling.layers import BatchDrop, Flatten, GeneralizedMeanPoolingP
from fastreid.layers import BatchDrop, Flatten, GeneralizedMeanPoolingP
@META_ARCH_REGISTRY.register()

View File

@ -9,7 +9,7 @@ from torch import nn
import torch.nn.functional as F
from fastreid.modeling.model_utils import *
from fastreid.modeling.layers import NoBiasBatchNorm1d
from fastreid.layers import NoBiasBatchNorm1d
class MaskUnit(nn.Module):

View File

@ -12,7 +12,7 @@ from .build import META_ARCH_REGISTRY
from ..model_utils import weights_init_kaiming
from ..backbones import build_backbone
from ..heads import build_reid_heads
from fastreid.modeling.layers import Flatten
from fastreid.layers import Flatten
@META_ARCH_REGISTRY.register()

View File

@ -13,7 +13,7 @@ from ..backbones import build_backbone
from ..backbones.resnet import Bottleneck
from ..heads import build_reid_heads
from ..model_utils import weights_init_kaiming
from fastreid.modeling.layers import GeneralizedMeanPoolingP, Flatten
from ...layers import GeneralizedMeanPoolingP, Flatten
@META_ARCH_REGISTRY.register()
@ -204,5 +204,3 @@ class MGN(nn.Module):
b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
return nn.functional.normalize(pred_feat)

View File

@ -12,7 +12,7 @@ from .build import META_ARCH_REGISTRY
from ..backbones import build_backbone
from ..heads import build_reid_heads
from ..model_utils import weights_init_kaiming
from fastreid.modeling.layers import Flatten
from fastreid.layers import Flatten
@META_ARCH_REGISTRY.register()

View File

@ -6,4 +6,4 @@ from .radam import RAdam, PlainRAdam, AdamW
from .ralamb import Ralamb
from .ranger import Ranger
from torch.optim import *
from torch.optim import *

View File

@ -3,28 +3,25 @@
####
import collections
import math
import torch
from torch.optim.optimizer import Optimizer
from torch.utils.tensorboard import SummaryWriter
try:
from tensorboardX import SummaryWriter
def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
"""Log a histogram of trust ratio scalars in across layers."""
results = collections.defaultdict(list)
for group in optimizer.param_groups:
for p in group['params']:
state = optimizer.state[p]
for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
if i in state:
results[i].append(state[i])
def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
"""Log a histogram of trust ratio scalars in across layers."""
results = collections.defaultdict(list)
for group in optimizer.param_groups:
for p in group['params']:
state = optimizer.state[p]
for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
if i in state:
results[i].append(state[i])
for k, v in results.items():
event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
for k, v in results.items():
event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
except ModuleNotFoundError as e:
print("To use this log_lamb_rs, please run 'pip install tensorboardx'. Also you must have Tensorboard running to see results")
class Lamb(Optimizer):
r"""Implements Lamb algorithm.
@ -102,7 +99,7 @@ class Lamb(Optimizer):
# bias_correction1 = 1 - beta1 ** state['step']
# bias_correction2 = 1 - beta2 ** state['step']
# Apply bias to lr to avoid broadcast.
step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
@ -123,4 +120,4 @@ class Lamb(Optimizer):
p.data.add_(-step_size * trust_ratio, adam_step)
return loss
return loss

View File

@ -1,5 +1,6 @@
# AGW Baseline in FastReID
## Training
To train a model, run

View File

@ -4,6 +4,13 @@ MODEL:
HEADS:
NUM_CLASSES: 702
SOLVER:
MAX_ITER: 23000
STEPS: [10000, 18000]
WARMUP_ITERS: 2500
DATASETS:
NAMES: ("DukeMTMC",)
TESTS: ("DukeMTMC",)

View File

@ -4,6 +4,13 @@ MODEL:
HEADS:
NUM_CLASSES: 751
SOLVER:
MAX_ITER: 18000
STEPS: [8000, 14000]
WARMUP_ITERS: 2000
DATASETS:
NAMES: ("Market1501",)
TESTS: ("Market1501",)

View File

@ -9,10 +9,10 @@ DATASETS:
TESTS: ("MSMT17",)
SOLVER:
MAX_ITER: 45000
STEPS: [20000, 35000]
MAX_ITER: 42000
STEPS: [19000, 33000]
WARMUP_ITERS: 4700
LOG_PERIOD: 500
CHECKPOINT_PERIOD: 5000
TEST:

View File

@ -2,23 +2,29 @@ MODEL:
META_ARCHITECTURE: 'Baseline'
BACKBONE:
NAME: "build_resnetNL_backbone"
NAME: "build_resnet_backbone"
DEPTH: 50
LAST_STRIDE: 1
WITH_IBN: False
WITH_NL: True
PRETRAIN: True
HEADS:
NAME: "GemHead"
NAME: "BNneckHead"
POOL_LAYER: "gempool"
CLS_LAYER: "linear"
NUM_CLASSES: 702
LOSSES:
NAME: ("CrossEntropyLoss", "TripletLoss")
SMOOTH_ON: True
SCALE_CE: 1.0
CE:
EPSILON: 0.1
SCALE: 1.0
MARGIN: 0.0
SCALE_TRI: 1.0
TRI:
MARGIN: 0.0
HARD_MINING: False
USE_COSINE_DIST: False
SCALE: 1.0
DATASETS:
NAMES: ("DukeMTMC",)
@ -27,15 +33,12 @@ DATASETS:
INPUT:
SIZE_TRAIN: [256, 128]
SIZE_TEST: [256, 128]
RE:
REA:
ENABLED: True
PROB: 0.5
CUTOUT:
ENABLED: False
MEAN: [123.675, 116.28, 103.53]
DO_PAD: True
DO_LIGHTING: False
DATALOADER:
PK_SAMPLER: True
NUM_INSTANCE: 4
@ -64,4 +67,4 @@ TEST:
CUDNN_BENCHMARK: True
OUTPUT_DIR: "logs/fastreid_dukemtmc/ibn_softmax_softtriplet"
OUTPUT_DIR: "logs"

View File

@ -0,0 +1,3 @@
gpus='1'
CUDA_VISIBLE_DEVICES=$gpus python train_net.py --config-file 'configs/AGW_dukemtmc.yml'

View File

@ -0,0 +1,3 @@
gpus='0'
CUDA_VISIBLE_DEVICES=$gpus python train_net.py --config-file 'configs/AGW_market1501.yml'

View File

@ -0,0 +1,3 @@
gpus='3'
CUDA_VISIBLE_DEVICES=$gpus python train_net.py --config-file 'configs/AGW_msmt17.yml'

View File

@ -13,8 +13,6 @@ from fastreid.engine import DefaultTrainer, default_argument_parser, default_set
from fastreid.utils.checkpoint import Checkpointer
from fastreid.evaluation import ReidEvaluator
from agwbaseline import *
class Trainer(DefaultTrainer):
@classmethod
@ -23,6 +21,7 @@ class Trainer(DefaultTrainer):
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
return ReidEvaluator(cfg, num_query)
def setup(args):
"""
Create configs and perform basic setups.
@ -59,4 +58,3 @@ if __name__ == "__main__":
args = default_argument_parser().parse_args()
print("Command Line Args:", args)
main(args)

View File

@ -1,9 +0,0 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
from .gem_pool import GemHead
from .resnet_nl import build_resnetNL_backbone
from .wr_triplet_loss import WeightedRegularizedTriplet

View File

@ -1,84 +0,0 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
from torch import nn
from fastreid.modeling.model_utils import weights_init_kaiming, weights_init_classifier
from fastreid.modeling.heads import REID_HEADS_REGISTRY
from fastreid.layers import bn_no_bias
from fastreid.modeling.heads import StandardHead
class GeneralizedMeanPooling(nn.Module):
r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
- At p = infinity, one gets Max Pooling
- At p = 1, one gets Average Pooling
The output is of size H x W, for any input size.
The number of output features is equal to the number of input planes.
Args:
output_size: the target output size of the image of the form H x W.
Can be a tuple (H, W) or a single H for a square image H x H
H and W can be either a ``int``, or ``None`` which means the size will
be the same as that of the input.
"""
def __init__(self, norm, output_size=1, eps=1e-6):
super(GeneralizedMeanPooling, self).__init__()
assert norm > 0
self.p = float(norm)
self.output_size = output_size
self.eps = eps
def forward(self, x):
x = x.clamp(min=self.eps).pow(self.p)
return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)
def __repr__(self):
return self.__class__.__name__ + '(' \
+ str(self.p) + ', ' \
+ 'output_size=' + str(self.output_size) + ')'
class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
""" Same, but norm is trainable
"""
def __init__(self, norm=3, output_size=1, eps=1e-6):
super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
self.p = nn.Parameter(torch.ones(1) * norm)
@REID_HEADS_REGISTRY.register()
class GemHead(nn.Module):
def __init__(self, cfg):
super().__init__()
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.gem_pool = GeneralizedMeanPoolingP()
self.bnneck = bn_no_bias(2048)
self.bnneck.apply(weights_init_kaiming)
self.classifier = nn.Linear(2048, self._num_classes, bias=False)
self.classifier.apply(weights_init_classifier)
def forward(self, features, targets=None):
global_features = self.gem_pool(features)
global_features = global_features.view(global_features.shape[0], -1)
bn_features = self.bnneck(global_features)
if not self.training:
return F.normalize(bn_features)
pred_class_logits = self.classifier(bn_features)
return pred_class_logits, global_features, targets
@classmethod
def losses(cls, cfg, pred_class_logits, global_features, gt_classes):
return StandardHead.losses(cfg, pred_class_logits, global_features, gt_classes)

View File

@ -1,159 +0,0 @@
# encoding: utf-8
import logging
import math
import torch
from torch import nn
from fastreid.modeling.backbones import BACKBONE_REGISTRY
from fastreid.modeling.backbones.resnet import Bottleneck, model_zoo, model_urls
from .non_local_layer import Non_local
class ResNetNL(nn.Module):
def __init__(self, last_stride, with_ibn, block=Bottleneck, layers=[3, 4, 6, 3], non_layers=[0, 2, 3, 0]):
self.inplanes = 64
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], with_ibn=with_ibn)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, with_ibn=with_ibn)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn)
self.layer4 = self._make_layer(
block, 512, layers[3], stride=last_stride)
self.NL_1 = nn.ModuleList(
[Non_local(256) for i in range(non_layers[0])])
self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
self.NL_2 = nn.ModuleList(
[Non_local(512) for i in range(non_layers[1])])
self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
self.NL_3 = nn.ModuleList(
[Non_local(1024) for i in range(non_layers[2])])
self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
self.NL_4 = nn.ModuleList(
[Non_local(2048) for i in range(non_layers[3])])
self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
if planes == 512:
with_ibn = False
layers.append(block(self.inplanes, planes, with_ibn, stride=stride, downsample=downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, with_ibn))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
NL1_counter = 0
if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1]
for i in range(len(self.layer1)):
x = self.layer1[i](x)
if i == self.NL_1_idx[NL1_counter]:
_, C, H, W = x.shape
x = self.NL_1[NL1_counter](x)
NL1_counter += 1
# Layer 2
NL2_counter = 0
if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1]
for i in range(len(self.layer2)):
x = self.layer2[i](x)
if i == self.NL_2_idx[NL2_counter]:
_, C, H, W = x.shape
x = self.NL_2[NL2_counter](x)
NL2_counter += 1
# Layer 3
NL3_counter = 0
if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1]
for i in range(len(self.layer3)):
x = self.layer3[i](x)
if i == self.NL_3_idx[NL3_counter]:
_, C, H, W = x.shape
x = self.NL_3[NL3_counter](x)
NL3_counter += 1
# Layer 4
NL4_counter = 0
if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1]
for i in range(len(self.layer4)):
x = self.layer4[i](x)
if i == self.NL_4_idx[NL4_counter]:
_, C, H, W = x.shape
x = self.NL_4[NL4_counter](x)
NL4_counter += 1
return x
def load_param(self, model_path):
param_dict = torch.load(model_path)
for i in param_dict:
if 'fc' in i:
continue
self.state_dict()[i].copy_(param_dict[i])
def random_init(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
@BACKBONE_REGISTRY.register()
def build_resnetNL_backbone(cfg):
"""
Create a ResNet Non-local instance from config.
Returns:
ResNet: a :class:`ResNet` instance.
"""
# fmt: off
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_se = cfg.MODEL.BACKBONE.WITH_SE
depth = cfg.MODEL.BACKBONE.DEPTH
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]
nl_layers_per_stage = [0, 2, 3, 0]
model = ResNetNL(last_stride, with_ibn, Bottleneck, num_blocks_per_stage, nl_layers_per_stage)
if pretrain:
if not with_ibn:
# original resnet
state_dict = model_zoo.load_url(model_urls[depth])
else:
# ibn resnet
state_dict = torch.load(pretrain_path)['state_dict']
# remove module in name
new_state_dict = {}
for k in state_dict:
new_k = '.'.join(k.split('.')[1:])
if model.state_dict()[new_k].shape == state_dict[k].shape:
new_state_dict[new_k] = state_dict[k]
state_dict = new_state_dict
res = model.load_state_dict(state_dict, strict=False)
logger = logging.getLogger('fastreid.'+__name__)
logger.info('missing keys is {}'.format(res.missing_keys))
logger.info('unexpected keys is {}'.format(res.unexpected_keys))
return model

View File

@ -1,50 +0,0 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
from fastreid.modeling.losses.margin_loss import normalize, euclidean_dist
def softmax_weights(dist, mask):
max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
diff = dist - max_v
Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero
W = torch.exp(diff) * mask / Z
return W
class WeightedRegularizedTriplet(object):
def __init__(self, cfg):
self.ranking_loss = nn.SoftMarginLoss()
self._normalize_feature = False
def __call__(self, pred_class_logits, global_feat, labels):
if self._normalize_feature:
global_feat = normalize(global_feat, axis=-1)
dist_mat = euclidean_dist(global_feat, global_feat)
N = dist_mat.size(0)
# shape [N, N]
is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()).float()
is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()).float()
# `dist_ap` means distance(anchor, positive)
# both `dist_ap` and `relative_p_inds` with shape [N, 1]
dist_ap = dist_mat * is_pos
dist_an = dist_mat * is_neg
weights_ap = softmax_weights(dist_ap, is_pos)
weights_an = softmax_weights(-dist_an, is_neg)
furthest_positive = torch.sum(dist_ap * weights_ap, dim=1)
closest_negative = torch.sum(dist_an * weights_an, dim=1)
y = furthest_positive.new().resize_as_(furthest_positive).fill_(1)
loss = self.ranking_loss(closest_negative - furthest_positive, y)
return {
"loss_wrTriplet": loss,
}

View File

@ -0,0 +1,83 @@
# Bag of Tricks and A Strong ReID Baseline in FastReID
Bag of Tricks and A Strong Baseline for Deep Person Re-identification. CVPRW2019, Oral.
[Hao Luo\*](https://github.com/michuanhaohao) [Youzhi Gu\*](https://github.com/shaoniangu) [Xingyu Liao\*](https://github.com/L1aoXingyu) [Shenqi Lai](https://github.com/xiaolai-sqlai)
A Strong Baseline and Batch Normalization Neck for Deep Person Re-identification. IEEE Transactions on Multimedia (Accepted).
[[Journal Version(TMM)]](https://ieeexplore.ieee.org/document/8930088)
[[PDF]](http://openaccess.thecvf.com/content_CVPRW_2019/papers/TRMTMCT/Luo_Bag_of_Tricks_and_a_Strong_Baseline_for_Deep_Person_CVPRW_2019_paper.pdf)
[[Slides]](https://drive.google.com/open?id=1h9SgdJenvfoNp9PTUxPiz5_K5HFCho-V)
[[Poster]](https://drive.google.com/open?id=1izZYAwylBsrldxSMqHCH432P6hnyh1vR)
## Training
To train a model, run
```bash
CUDA_VISIBLE_DEVICES=gpus python train_net.py --config-file <config.yml>
```
For example, to launch a end-to-end baseline training on market1501 dataset on GPU#1,
one should excute:
```bash
CUDA_VISIBLE_DEVICES=1 python train_net.py --config-file='configs/bagtricks_market1501.yml'
```
## Evaluation
To evaluate the model in test set, run similarly:
```bash
CUDA_VISIBLE_DEVICES=gpus python train_net.py --config-file <configs.yaml> --eval-only MODEL.WEIGHTS model.pth
```
## Experimental Results
You can reproduce the results by simply excute
```bash
sh scripts/train_market.sh
sh scripts/train_duke.sh
sh scripts/train_msmt.sh
```
### Market1501 dataset
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: |:---: | :---: |
| BagTricks | ImageNet | 93.9% | 84.9% | 57.1% |
### DukeMTMC dataset
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: |:---: | :---: |
| BagTricks | ImageNet | 87.1% | 76.4% | 39.2% |
### MSMT17 dataset
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: |:---: | :---: |
| BagTricks | ImageNet | 72.2% | 48.4% | 9.6% |
```
@InProceedings{Luo_2019_CVPR_Workshops,
author = {Luo, Hao and Gu, Youzhi and Liao, Xingyu and Lai, Shenqi and Jiang, Wei},
title = {Bag of Tricks and a Strong Baseline for Deep Person Re-Identification},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
month = {June},
year = {2019}
}
@ARTICLE{Luo_2019_Strong_TMM,
author={H. {Luo} and W. {Jiang} and Y. {Gu} and F. {Liu} and X. {Liao} and S. {Lai} and J. {Gu}},
journal={IEEE Transactions on Multimedia},
title={A Strong Baseline and Batch Normalization Neck for Deep Person Re-identification},
year={2019},
pages={1-1},
doi={10.1109/TMM.2019.2958756},
ISSN={1941-0077},
}
```

View File

@ -1,23 +1,30 @@
MODEL:
META_ARCHITECTURE: 'Baseline'
OPEN_LAYERS: ""
BACKBONE:
NAME: "build_resnet_backbone"
DEPTH: 50
LAST_STRIDE: 1
WITH_IBN: False
PRETRAIN: True
HEADS:
NAME: "BNneckHead"
CLS_LAYER: "linear"
LOSSES:
NAME: ("CrossEntropyLoss", "TripletLoss")
SMOOTH_ON: True
SCALE_CE: 1.0
MARGIN: 0.3
SCALE_TRI: 1.0
CE:
EPSILON: 0.1
SCALE: 1.0
TRI:
MARGIN: 0.3
HARD_MINING: True
USE_COSINE_DIST: False
SCALE: 1.0
DATASETS:
NAMES: ("DukeMTMC",)
@ -26,14 +33,12 @@ DATASETS:
INPUT:
SIZE_TRAIN: [256, 128]
SIZE_TEST: [256, 128]
RE:
REA:
ENABLED: True
PROB: 0.5
CUTOUT:
ENABLED: False
MEAN: [123.675, 116.28, 103.53]
DO_PAD: True
DO_LIGHTING: False
DATALOADER:
PK_SAMPLER: True
@ -44,9 +49,9 @@ SOLVER:
OPT: "Adam"
MAX_ITER: 18000
BASE_LR: 0.00035
BIAS_LR_FACTOR: 2
BIAS_LR_FACTOR: 2.
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0
WEIGHT_DECAY_BIAS: 0.
IMS_PER_BATCH: 64
STEPS: [8000, 14000]
@ -56,7 +61,7 @@ SOLVER:
WARMUP_ITERS: 2000
LOG_PERIOD: 200
CHECKPOINT_PERIOD: 6000
CHECKPOINT_PERIOD: 2000
TEST:
EVAL_PERIOD: 2000

View File

@ -0,0 +1,18 @@
_BASE_: "Base-bagtricks.yml"
MODEL:
HEADS:
NUM_CLASSES: 702
SOLVER:
MAX_ITER: 23000
STEPS: [10000, 18000]
WARMUP_ITERS: 2500
DATASETS:
NAMES: ("DukeMTMC",)
TESTS: ("DukeMTMC",)
OUTPUT_DIR: "logs/dukemtmc/bagtricks"

View File

@ -1,9 +1,16 @@
_BASE_: "Base-Strongbaseline.yml"
_BASE_: "Base-bagtricks.yml"
MODEL:
HEADS:
NUM_CLASSES: 751
SOLVER:
MAX_ITER: 18000
STEPS: [8000, 14000]
WARMUP_ITERS: 2000
DATASETS:
NAMES: ("Market1501",)
TESTS: ("Market1501",)

View File

@ -0,0 +1,21 @@
_BASE_: "Base-bagtricks.yml"
MODEL:
HEADS:
NUM_CLASSES: 1041
DATASETS:
NAMES: ("MSMT17",)
TESTS: ("MSMT17",)
SOLVER:
MAX_ITER: 42000
STEPS: [19000, 33000]
WARMUP_ITERS: 4700
CHECKPOINT_PERIOD: 5000
TEST:
EVAL_PERIOD: 5000
OUTPUT_DIR: "logs/msmt17/bagtricks"

View File

@ -0,0 +1,48 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import sys
import torch
sys.path.append('../..')
from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup
from fastreid.modeling.meta_arch import build_model
from fastreid.export.tensorflow_export import export_tf_reid_model
from fastreid.export.tf_modeling import TfMetaArch
def setup(args):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg()
# cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
default_setup(cfg, args)
return cfg
if __name__ == "__main__":
args = default_argument_parser().parse_args()
print("Command Line Args:", args)
cfg = setup(args)
cfg.defrost()
cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
cfg.MODEL.BACKBONE.DEPTH = 50
cfg.MODEL.BACKBONE.LAST_STRIDE = 1
# If use IBN block in backbone
cfg.MODEL.BACKBONE.WITH_IBN = False
cfg.MODEL.BACKBONE.PRETRAIN = False
from torchvision.models import resnet50
# model = TfMetaArch(cfg)
model = resnet50(pretrained=False)
# model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
model.eval()
dummy_inputs = torch.randn(1, 3, 256, 128)
export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')

View File

@ -0,0 +1,3 @@
gpus='1'
CUDA_VISIBLE_DEVICES=$gpus python train_net.py --config-file 'configs/bagtricks_dukemtmc.yml'

View File

@ -0,0 +1,3 @@
gpus='0'
CUDA_VISIBLE_DEVICES=$gpus python train_net.py --config-file 'configs/bagtricks_market1501.yml'

View File

@ -0,0 +1,3 @@
gpus='2'
CUDA_VISIBLE_DEVICES=$gpus python train_net.py --config-file 'configs/bagtricks_msmt17.yml'

View File

@ -14,7 +14,6 @@ from fastreid.config import get_cfg
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
from fastreid.utils.checkpoint import Checkpointer
from fastreid.evaluation import ReidEvaluator
from reduce_head import ReduceHead
class Trainer(DefaultTrainer):
@ -46,9 +45,7 @@ def main(args):
model = Trainer.build_model(cfg)
model = nn.DataParallel(model)
model = model.cuda()
Checkpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
cfg.MODEL.WEIGHTS, resume=args.resume
)
Checkpointer(model, save_dir=cfg.OUTPUT_DIR).load(cfg.MODEL.WEIGHTS)
res = Trainer.test(cfg, model)
return res

View File

@ -1,39 +0,0 @@
# Strong Baseline in FastReID
## Training
To train a model, run
```bash
CUDA_VISIBLE_DEVICES=gpus python train_net.py --config-file <config.yaml>
```
For example, to launch a end-to-end baseline training on market1501 dataset with ibn-net on 4 GPUs,
one should excute:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_net.py --config-file='configs/baseline_ibn_market1501.yml'
```
## Experimental Results
### Market1501 dataset
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: |:---: | :---: |
| BagTricks | ImageNet | 93.6% | 85.1% | 58.1% |
| BagTricks + Ibn-a | ImageNet | 94.8% | 87.3% | 63.5% |
### DukeMTMC dataset
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: |:---: | :---: |
| BagTricks | ImageNet | 86.1% | 75.9% | 38.7% |
| BagTricks + Ibn-a | ImageNet | 89.0% | 78.8% | 43.6% |
### MSMT17 dataset
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: |:---: | :---: |
| BagTricks | ImageNet | 70.4% | 47.5% | 9.6% |
| BagTricks + Ibn-a | ImageNet | 76.9% | 55.0% | 13.5% |

View File

@ -1,7 +0,0 @@
_BASE_: "Base-Strongbaseline.yml"
MODEL:
BACKBONE:
WITH_IBN: True
PRETRAIN_PATH: "/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"

View File

@ -1,43 +0,0 @@
_BASE_: "Base-Strongbaseline.yml"
MODEL:
META_ARCHITECTURE: "MGN_v2"
HEADS:
POOL_LAYER: "maxpool"
NAME: "StandardHead"
NUM_CLASSES: 702
LOSSES:
NAME: ("CrossEntropyLoss", "TripletLoss")
SMOOTH_ON: True
SCALE_CE: 0.1
MARGIN: 0.3
SCALE_TRI: 0.167
INPUT:
RE:
ENABLED: True
PROB: 0.5
CUTOUT:
ENABLED: False
SOLVER:
MAX_ITER: 9000
BASE_LR: 0.00035
BIAS_LR_FACTOR: 2
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0
IMS_PER_BATCH: 256
STEPS: [4000, 7000]
GAMMA: 0.1
WARMUP_FACTOR: 0.01
WARMUP_ITERS: 1000
DATASETS:
NAMES: ("DukeMTMC",)
TESTS: ("DukeMTMC",)
OUTPUT_DIR: "logs/dukemtmc/mgn_v2"

View File

@ -1,11 +0,0 @@
_BASE_: "Base-Strongbaseline_ibn.yml"
MODEL:
HEADS:
NUM_CLASSES: 702
DATASETS:
NAMES: ("DukeMTMC",)
TESTS: ("DukeMTMC",)
OUTPUT_DIR: "logs/dukemtmc/ibn_bagtricks"

View File

@ -1,11 +0,0 @@
_BASE_: "Base-Strongbaseline_ibn.yml"
MODEL:
HEADS:
NUM_CLASSES: 751
DATASETS:
NAMES: ("Market1501",)
TESTS: ("Market1501",)
OUTPUT_DIR: "logs/market1501/ibn_bagtricks"

View File

@ -1,22 +0,0 @@
_BASE_: "Base-Strongbaseline_ibn.yml"
MODEL:
HEADS:
NUM_CLASSES: 1041
DATASETS:
NAMES: ("MSMT17",)
TESTS: ("MSMT17",)
SOLVER:
MAX_ITER: 45000
STEPS: [20000, 35000]
WARMUP_ITERS: 2000
LOG_PERIOD: 500
CHECKPOINT_PERIOD: 15000
TEST:
EVAL_PERIOD: 15000
OUTPUT_DIR: "logs/msmt17/ibn_bagtricks"

View File

@ -1,22 +0,0 @@
_BASE_: "Base-Strongbaseline.yml"
MODEL:
HEADS:
NUM_CLASSES: 1041
DATASETS:
NAMES: ("MSMT17",)
TESTS: ("MSMT17",)
SOLVER:
MAX_ITER: 45000
STEPS: [20000, 35000]
WARMUP_ITERS: 2000
LOG_PERIOD: 500
CHECKPOINT_PERIOD: 15000
TEST:
EVAL_PERIOD: 15000
OUTPUT_DIR: "logs/msmt17/bagtricks"