mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
实现和线上一模一样的pcb
This commit is contained in:
parent
eabe896ee8
commit
44a90ae025
@ -4,14 +4,14 @@
|
|||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from se_pcb_net import build_senet_pcb_backbone
|
||||||
from .build import build_backbone, BACKBONE_REGISTRY
|
from .build import build_backbone, BACKBONE_REGISTRY
|
||||||
|
|
||||||
from .resnet import build_resnet_backbone
|
|
||||||
from .osnet import build_osnet_backbone
|
|
||||||
from .resnest import build_resnest_backbone
|
|
||||||
from .resnext import build_resnext_backbone
|
|
||||||
from .regnet import build_regnet_backbone, build_effnet_backbone
|
|
||||||
from .shufflenet import build_shufflenetv2_backbone
|
|
||||||
from .mobilenet import build_mobilenetv2_backbone
|
from .mobilenet import build_mobilenetv2_backbone
|
||||||
|
from .osnet import build_osnet_backbone
|
||||||
|
from .regnet import build_regnet_backbone, build_effnet_backbone
|
||||||
from .repvgg import build_repvgg_backbone
|
from .repvgg import build_repvgg_backbone
|
||||||
|
from .resnest import build_resnest_backbone
|
||||||
|
from .resnet import build_resnet_backbone
|
||||||
|
from .resnext import build_resnext_backbone
|
||||||
|
from .shufflenet import build_shufflenetv2_backbone
|
||||||
from .vision_transformer import build_vit_backbone
|
from .vision_transformer import build_vit_backbone
|
||||||
|
129
fastreid/modeling/backbones/se_pcb_net.py
Normal file
129
fastreid/modeling/backbones/se_pcb_net.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# @Time : 2021/10/28 23:06:51
|
||||||
|
# @Author : zuchen.wang@vipshop.com
|
||||||
|
# @File : senet.py
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import pretrainedmodels
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from fastreid.config.config import CfgNode
|
||||||
|
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
|
||||||
|
from .build import BACKBONE_REGISTRY
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SePcbNet(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
part_num: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
part_dim: int,
|
||||||
|
last_stride: Tuple[int, int]
|
||||||
|
):
|
||||||
|
super(SePcbNet, self).__init__()
|
||||||
|
self.part_num = part_num
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.part_dim = part_dim
|
||||||
|
self.last_stride = last_stride
|
||||||
|
|
||||||
|
self.cnn = pretrainedmodels.__dict__["se_resnext101_32x4d"](pretrained='imagenet')
|
||||||
|
self.cnn.layer4[0].downsample[0].stride = self.last_stride
|
||||||
|
self.cnn.layer4[0].conv2.stride = self.last_stride
|
||||||
|
self.cnn.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
|
self.avg_pool_part6 = nn.AdaptiveAvgPool2d((self.part_num, 1))
|
||||||
|
|
||||||
|
for i in range(self.part_num):
|
||||||
|
setattr(self, 'reduction_' + str(i),
|
||||||
|
nn.Sequential(
|
||||||
|
nn.Conv2d(self.embedding_dim, self.part_dim, (1, 1), bias=False),
|
||||||
|
nn.BatchNorm2d(self.part_num),
|
||||||
|
nn.ReLU()
|
||||||
|
))
|
||||||
|
|
||||||
|
self.random_init()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.cnn.layer0(x)
|
||||||
|
x = self.cnn.layer1(x)
|
||||||
|
x = self.cnn.layer2(x)
|
||||||
|
x = self.cnn.layer3(x)
|
||||||
|
x = self.cnn.layer4(x)
|
||||||
|
|
||||||
|
x_full = self.cnn.avg_pool(x)
|
||||||
|
x_full = x_full.reshape(x_full.shape[0], -1)
|
||||||
|
|
||||||
|
x_part = self.avg_pool_part6(x)
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
for i in range(self.part_num):
|
||||||
|
reduction_i = getattr(self, 'reduction_' + str(i))
|
||||||
|
part_i = x_part[:, :, i: i + 1, :]
|
||||||
|
parts.append(reduction_i(part_i).squeeze())
|
||||||
|
|
||||||
|
return {
|
||||||
|
'full': x_full,
|
||||||
|
'parts': parts,
|
||||||
|
}
|
||||||
|
|
||||||
|
def random_init(self):
|
||||||
|
self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||||
|
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
m.weight.data.fill_(1)
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, nn.InstanceNorm2d):
|
||||||
|
m.weight.data.fill_(1)
|
||||||
|
m.bias.data.zero_()
|
||||||
|
|
||||||
|
|
||||||
|
@BACKBONE_REGISTRY.register()
|
||||||
|
def build_senet_pcb_backbone(cfg: CfgNode):
|
||||||
|
# fmt: off
|
||||||
|
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
|
||||||
|
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
|
||||||
|
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
|
||||||
|
part_num = cfg.MODEL.PCB.PART_NUM
|
||||||
|
part_dim = cfg.MODEL.PCB.PART_DIM
|
||||||
|
embedding_dim = cfg.MODEL.PCB.EMBEDDING_DIM
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
model = SePcbNet(part_num=part_num, embedding_dim=embedding_dim, part_dim=part_dim, last_stride=last_stride)
|
||||||
|
|
||||||
|
if pretrain:
|
||||||
|
if pretrain_path:
|
||||||
|
try:
|
||||||
|
state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
|
||||||
|
new_state_dict = {}
|
||||||
|
for k in state_dict:
|
||||||
|
new_k = '.'.join(k.split('.')[2:])
|
||||||
|
if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
|
||||||
|
new_state_dict[new_k] = state_dict[k]
|
||||||
|
state_dict = new_state_dict
|
||||||
|
logger.info(f"Loading pretrained model from {pretrain_path}")
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.error(f'{pretrain_path} is not found! Please check this path.')
|
||||||
|
raise e
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error("State dict keys error! Please check the state dict.")
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
logger.info('Not config pretrained mode with SePcbNet, the weights will be random init')
|
||||||
|
|
||||||
|
incompatible = model.load_state_dict(state_dict, strict=False)
|
||||||
|
if incompatible.missing_keys:
|
||||||
|
logger.info(
|
||||||
|
get_missing_parameters_message(incompatible.missing_keys)
|
||||||
|
)
|
||||||
|
if incompatible.unexpected_keys:
|
||||||
|
logger.info(
|
||||||
|
get_unexpected_parameters_message(incompatible.unexpected_keys)
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
@ -1,15 +1,15 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from fastreid.config import configurable
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from fastreid.config import CfgNode
|
from fastreid.config import CfgNode
|
||||||
|
from fastreid.config import configurable
|
||||||
from fastreid.layers import weights_init_classifier
|
from fastreid.layers import weights_init_classifier
|
||||||
from fastreid.layers import any_softmax
|
from fastreid.modeling.heads import REID_HEADS_REGISTRY
|
||||||
from fastreid.modeling.heads import REID_HEADS_REGISTRY, EmbeddingHead
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -45,10 +45,6 @@ class PcbHead(nn.Module):
|
|||||||
self.full_dim = full_dim
|
self.full_dim = full_dim
|
||||||
self.part_dim = part_dim
|
self.part_dim = part_dim
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
# self.num_classes = num_classes
|
|
||||||
# self.cls_type = cls_type
|
|
||||||
# self.scale = scale
|
|
||||||
# self.margin = margin
|
|
||||||
|
|
||||||
self.match_full = nn.Sequential(
|
self.match_full = nn.Sequential(
|
||||||
nn.Dropout(p=0.5),
|
nn.Dropout(p=0.5),
|
||||||
|
70
fastreid/modeling/meta_arch/pcb_oneline.py
Normal file
70
fastreid/modeling/meta_arch/pcb_oneline.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
# coding: utf-8
|
||||||
|
"""
|
||||||
|
Sun, Y. , Zheng, L. , Yang, Y. , Tian, Q. , & Wang, S. . (2017). Beyond part models: person retrieval with refined part pooling (and a strong convolutional baseline). Springer, Cham.
|
||||||
|
实现和线上一模一样的PCB
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from fastreid.modeling.losses import cross_entropy_loss, log_accuracy
|
||||||
|
from fastreid.modeling.meta_arch import Baseline
|
||||||
|
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY
|
||||||
|
|
||||||
|
|
||||||
|
@META_ARCH_REGISTRY.register()
|
||||||
|
class PcbOnline(Baseline):
|
||||||
|
|
||||||
|
def forward(self, batched_inputs):
|
||||||
|
# TODO: 减均值除方差
|
||||||
|
|
||||||
|
q = batched_inputs['q_img']
|
||||||
|
x = batched_inputs['p_img']
|
||||||
|
qf = self.cnn(q)
|
||||||
|
xf = self.cnn(x)
|
||||||
|
# L2 norm TODO: 查看具体数值,确定他到底是什么范数
|
||||||
|
qf = self._norm(qf)
|
||||||
|
xf = self._norm(xf)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
qf_full, qf_part_0, qf_part_1, qf_part_2 = torch.split(qf, [2048, 512, 512, 512], dim=-1)
|
||||||
|
xf_full, xf_part_0, xf_part_1, xf_part_2 = torch.split(xf, [2048, 512, 512, 512], dim=-1)
|
||||||
|
|
||||||
|
m_full = self.dense_matching_full(
|
||||||
|
torch.cat([qf_full, xf_full, (qf_full - xf_full).abs(), qf_full * xf_full], dim=-1))
|
||||||
|
m_part_0 = self.dense_matching_part_0(
|
||||||
|
torch.cat([qf_part_0, xf_part_0, (qf_part_0 - xf_part_0).abs(), qf_part_0 * xf_part_0], dim=-1))
|
||||||
|
m_part_1 = self.dense_matching_part_1(
|
||||||
|
torch.cat([qf_part_1, xf_part_1, (qf_part_1 - xf_part_1).abs(), qf_part_1 * xf_part_1], dim=-1))
|
||||||
|
m_part_2 = self.dense_matching_part_2(
|
||||||
|
torch.cat([qf_part_2, xf_part_2, (qf_part_2 - xf_part_2).abs(), qf_part_2 * xf_part_2], dim=-1))
|
||||||
|
|
||||||
|
m_all = self.dense_matching_all(torch.cat([m_full, m_part_0, m_part_1, m_part_2], dim=-1))
|
||||||
|
|
||||||
|
return m_all.squeeze()
|
||||||
|
|
||||||
|
def losses(self, outputs, gt_labels):
|
||||||
|
"""
|
||||||
|
Compute loss from modeling's outputs, the loss function input arguments
|
||||||
|
must be the same as the outputs of the model forwarding.
|
||||||
|
"""
|
||||||
|
# model predictions
|
||||||
|
pred_class_logits = outputs['pred_class_logits'].detach()
|
||||||
|
cls_outputs = outputs['cls_outputs']
|
||||||
|
|
||||||
|
# Log prediction accuracy
|
||||||
|
log_accuracy(pred_class_logits, gt_labels)
|
||||||
|
|
||||||
|
loss_dict = {}
|
||||||
|
loss_names = self.loss_kwargs['loss_names']
|
||||||
|
|
||||||
|
if 'CrossEntropyLoss' in loss_names:
|
||||||
|
ce_kwargs = self.loss_kwargs.get('ce')
|
||||||
|
loss_dict['loss_cls'] = cross_entropy_loss(
|
||||||
|
cls_outputs,
|
||||||
|
gt_labels,
|
||||||
|
ce_kwargs.get('eps'),
|
||||||
|
ce_kwargs.get('alpha')
|
||||||
|
) * ce_kwargs.get('scale')
|
||||||
|
|
||||||
|
|
||||||
|
return loss_dict
|
Loading…
x
Reference in New Issue
Block a user