mirror of https://github.com/JDAI-CV/fast-reid.git
Implement PCB exactly matching the online (production) version
parent eabe896ee8
commit 44a90ae025
@@ -4,14 +4,14 @@
@contact: sherlockliao01@gmail.com
"""

from .se_pcb_net import build_senet_pcb_backbone
from .build import build_backbone, BACKBONE_REGISTRY

from .mobilenet import build_mobilenetv2_backbone
from .osnet import build_osnet_backbone
from .regnet import build_regnet_backbone, build_effnet_backbone
from .repvgg import build_repvgg_backbone
from .resnest import build_resnest_backbone
from .resnet import build_resnet_backbone
from .resnext import build_resnext_backbone
from .shufflenet import build_shufflenetv2_backbone
from .vision_transformer import build_vit_backbone
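With the import above in place, the decorated builder defined in the new file below is picked up by fastreid's backbone registry. A minimal lookup sketch (hedged: `cfg` stands in for a fully populated CfgNode that already carries the MODEL.BACKBONE.* and MODEL.PCB.* keys this commit relies on; the helper name make_backbone is illustrative):

from fastreid.modeling.backbones import BACKBONE_REGISTRY

def make_backbone(cfg):
    # fastreid's build_backbone(cfg) does essentially this name-based lookup
    return BACKBONE_REGISTRY.get(cfg.MODEL.BACKBONE.NAME)(cfg)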
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
# @Time   : 2021/10/28 23:06:51
# @Author : zuchen.wang@vipshop.com
# @File   : senet.py
import logging
import math
from typing import Tuple

import pretrainedmodels
import torch
from torch import nn

from fastreid.config.config import CfgNode
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .build import BACKBONE_REGISTRY

logger = logging.getLogger(__name__)


class SePcbNet(nn.Module):
    def __init__(self,
                 part_num: int,
                 embedding_dim: int,
                 part_dim: int,
                 last_stride: Tuple[int, int]
                 ):
        super(SePcbNet, self).__init__()
        self.part_num = part_num
        self.embedding_dim = embedding_dim
        self.part_dim = part_dim
        self.last_stride = last_stride

        self.cnn = pretrainedmodels.__dict__["se_resnext101_32x4d"](pretrained='imagenet')
        self.cnn.layer4[0].downsample[0].stride = self.last_stride
        self.cnn.layer4[0].conv2.stride = self.last_stride
        self.cnn.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # pools the final feature map into part_num horizontal stripes
        self.avg_pool_part6 = nn.AdaptiveAvgPool2d((self.part_num, 1))

        for i in range(self.part_num):
            setattr(self, 'reduction_' + str(i),
                    nn.Sequential(
                        nn.Conv2d(self.embedding_dim, self.part_dim, (1, 1), bias=False),
                        # BatchNorm channels must match the reduction output (part_dim), not part_num
                        nn.BatchNorm2d(self.part_dim),
                        nn.ReLU()
                    ))

        self.random_init()

    def forward(self, x):
        x = self.cnn.layer0(x)
        x = self.cnn.layer1(x)
        x = self.cnn.layer2(x)
        x = self.cnn.layer3(x)
        x = self.cnn.layer4(x)

        # global branch: pooled full embedding
        x_full = self.cnn.avg_pool(x)
        x_full = x_full.reshape(x_full.shape[0], -1)

        # part branch: one reduction head per horizontal stripe
        x_part = self.avg_pool_part6(x)
        parts = []
        for i in range(self.part_num):
            reduction_i = getattr(self, 'reduction_' + str(i))
            part_i = x_part[:, :, i: i + 1, :]
            parts.append(reduction_i(part_i).squeeze())

        return {
            'full': x_full,
            'parts': parts,
        }

    def random_init(self):
        # the stem conv lives inside the pretrainedmodels layer0 block
        self.cnn.layer0.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


@BACKBONE_REGISTRY.register()
def build_senet_pcb_backbone(cfg: CfgNode):
    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    part_num      = cfg.MODEL.PCB.PART_NUM
    part_dim      = cfg.MODEL.PCB.PART_DIM
    embedding_dim = cfg.MODEL.PCB.EMBEDDING_DIM
    # fmt: on

    model = SePcbNet(part_num=part_num, embedding_dim=embedding_dim, part_dim=part_dim, last_stride=last_stride)

    if pretrain:
        if pretrain_path:
            try:
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
                new_state_dict = {}
                for k in state_dict:
                    # strip the first two key components so checkpoint keys line up
                    # with this module's own state_dict naming
                    new_k = '.'.join(k.split('.')[2:])
                    # keep only parameters that exist here with matching shapes
                    if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
                        new_state_dict[new_k] = state_dict[k]
                state_dict = new_state_dict
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.error(f'{pretrain_path} is not found! Please check this path.')
                raise e
            except KeyError as e:
                logger.error("State dict keys error! Please check the state dict.")
                raise e

            # only load when a checkpoint was actually read, so the no-path branch cannot crash
            incompatible = model.load_state_dict(state_dict, strict=False)
            if incompatible.missing_keys:
                logger.info(
                    get_missing_parameters_message(incompatible.missing_keys)
                )
            if incompatible.unexpected_keys:
                logger.info(
                    get_unexpected_parameters_message(incompatible.unexpected_keys)
                )
        else:
            logger.info('No pretrain path configured for SePcbNet; weights keep their default initialization')

    return model
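A minimal usage sketch of the backbone above (hedged: the constructor arguments are illustrative rather than taken from any shipped config, and the se_resnext101_32x4d ImageNet weights must be downloadable by pretrainedmodels):

import torch

model = SePcbNet(part_num=3, embedding_dim=2048, part_dim=512, last_stride=(1, 1))
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 384, 128))
print(out['full'].shape)                # torch.Size([2, 2048])
print([p.shape for p in out['parts']])  # three tensors of shape torch.Size([2, 512])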
@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-

import logging

import torch
import torch.nn.functional as F
from torch import nn

from fastreid.config import CfgNode
from fastreid.config import configurable
from fastreid.layers import weights_init_classifier
from fastreid.layers import any_softmax
from fastreid.modeling.heads import REID_HEADS_REGISTRY, EmbeddingHead

logger = logging.getLogger(__name__)
@@ -45,10 +45,6 @@ class PcbHead(nn.Module):
        self.full_dim = full_dim
        self.part_dim = part_dim
        self.embedding_dim = embedding_dim
        # self.num_classes = num_classes
        # self.cls_type = cls_type
        # self.scale = scale
        # self.margin = margin

        self.match_full = nn.Sequential(
            nn.Dropout(p=0.5),
@@ -0,0 +1,70 @@
# coding: utf-8
"""
Sun, Y., Zheng, L., Yang, Y., Tian, Q., & Wang, S. (2017). Beyond part models: person retrieval with refined part pooling (and a strong convolutional baseline). Springer, Cham.
An implementation of PCB that exactly matches the online (production) version.
"""
import torch

from fastreid.modeling.losses import cross_entropy_loss, log_accuracy
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY


@META_ARCH_REGISTRY.register()
class PcbOnline(Baseline):

    def forward(self, batched_inputs):
        # TODO: subtract the mean and divide by the standard deviation (input normalization)

        q = batched_inputs['q_img']
        x = batched_inputs['p_img']
        qf = self.cnn(q)
        xf = self.cnn(x)
        # L2 norm  TODO: inspect the actual values to confirm which norm this really is
        qf = self._norm(qf)
        xf = self._norm(xf)

        # split the concatenated embedding into the global feature and three part features
        qf_full, qf_part_0, qf_part_1, qf_part_2 = torch.split(qf, [2048, 512, 512, 512], dim=-1)
        xf_full, xf_part_0, xf_part_1, xf_part_2 = torch.split(xf, [2048, 512, 512, 512], dim=-1)

        # pairwise interaction per branch: [q, x, |q - x|, q * x]
        m_full = self.dense_matching_full(
            torch.cat([qf_full, xf_full, (qf_full - xf_full).abs(), qf_full * xf_full], dim=-1))
        m_part_0 = self.dense_matching_part_0(
            torch.cat([qf_part_0, xf_part_0, (qf_part_0 - xf_part_0).abs(), qf_part_0 * xf_part_0], dim=-1))
        m_part_1 = self.dense_matching_part_1(
            torch.cat([qf_part_1, xf_part_1, (qf_part_1 - xf_part_1).abs(), qf_part_1 * xf_part_1], dim=-1))
        m_part_2 = self.dense_matching_part_2(
            torch.cat([qf_part_2, xf_part_2, (qf_part_2 - xf_part_2).abs(), qf_part_2 * xf_part_2], dim=-1))

        # fuse the four branch scores into a single matching score
        m_all = self.dense_matching_all(torch.cat([m_full, m_part_0, m_part_1, m_part_2], dim=-1))

        return m_all.squeeze()

    def losses(self, outputs, gt_labels):
        """
        Compute the loss from the model's outputs; the loss function's input
        arguments must match the outputs of the model's forward pass.
        """
        # model predictions
        pred_class_logits = outputs['pred_class_logits'].detach()
        cls_outputs = outputs['cls_outputs']

        # log prediction accuracy
        log_accuracy(pred_class_logits, gt_labels)

        loss_dict = {}
        loss_names = self.loss_kwargs['loss_names']

        if 'CrossEntropyLoss' in loss_names:
            ce_kwargs = self.loss_kwargs.get('ce')
            loss_dict['loss_cls'] = cross_entropy_loss(
                cls_outputs,
                gt_labels,
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale')

        return loss_dict
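For reference, the per-branch interaction used in forward above can be isolated as a small helper (a sketch only; the helper name pair_features is illustrative, and the dense_matching_* heads, defined elsewhere in this commit, are assumed to map the 4x-wide concatenation to a score):

import torch

def pair_features(q_feat: torch.Tensor, x_feat: torch.Tensor) -> torch.Tensor:
    # concatenate both embeddings with their absolute difference and element-wise product
    return torch.cat([q_feat, x_feat, (q_feat - x_feat).abs(), q_feat * x_feat], dim=-1)

q_full, x_full = torch.randn(4, 2048), torch.randn(4, 2048)
print(pair_features(q_full, x_full).shape)  # torch.Size([4, 8192])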