Implement a PCB model identical to the one running online (in production)

pull/608/head
zuchen.wang 2021-10-29 09:45:59 +08:00
parent eabe896ee8
commit 44a90ae025
4 changed files with 212 additions and 17 deletions

View File

@@ -4,14 +4,14 @@
@contact: sherlockliao01@gmail.com
"""
from .se_pcb_net import build_senet_pcb_backbone
from .build import build_backbone, BACKBONE_REGISTRY
from .mobilenet import build_mobilenetv2_backbone
from .osnet import build_osnet_backbone
from .regnet import build_regnet_backbone, build_effnet_backbone
from .repvgg import build_repvgg_backbone
from .resnest import build_resnest_backbone
from .resnet import build_resnet_backbone
from .resnext import build_resnext_backbone
from .shufflenet import build_shufflenetv2_backbone
from .vision_transformer import build_vit_backbone
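
Importing the builder here is what makes it discoverable: the BACKBONE_REGISTRY.register() decorator in the new file below registers build_senet_pcb_backbone under its function name. A minimal selection sketch (an illustration, not part of the commit; it presumes fastreid's usual get_cfg()/build_backbone(cfg) flow and that defaults for the MODEL.PCB keys read by the builder are defined in the project config):

from fastreid.config import get_cfg
from fastreid.modeling.backbones import build_backbone

cfg = get_cfg()
cfg.MODEL.BACKBONE.NAME = 'build_senet_pcb_backbone'  # registry key = builder function name
cfg.MODEL.BACKBONE.PRETRAIN = False                   # skip checkpoint loading in this sketch
backbone = build_backbone(cfg)                        # dispatches through BACKBONE_REGISTRY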

View File

@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/28 23:06:51
# @Author : zuchen.wang@vipshop.com
# @File : senet.py
import logging
import math
from typing import Tuple

import pretrainedmodels
import torch
from torch import nn

from fastreid.config.config import CfgNode
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .build import BACKBONE_REGISTRY

logger = logging.getLogger(__name__)
class SePcbNet(nn.Module):
    def __init__(self,
                 part_num: int,
                 embedding_dim: int,
                 part_dim: int,
                 last_stride: Tuple[int, int]
                 ):
        super(SePcbNet, self).__init__()
        self.part_num = part_num
        self.embedding_dim = embedding_dim
        self.part_dim = part_dim
        self.last_stride = last_stride

        self.cnn = pretrainedmodels.__dict__["se_resnext101_32x4d"](pretrained='imagenet')
        # Shrink the last stride so the final feature map stays large enough to split into parts.
        self.cnn.layer4[0].downsample[0].stride = self.last_stride
        self.cnn.layer4[0].conv2.stride = self.last_stride
        self.cnn.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Pools the feature map into part_num horizontal stripes.
        self.avg_pool_part6 = nn.AdaptiveAvgPool2d((self.part_num, 1))
        for i in range(self.part_num):
            setattr(self, 'reduction_' + str(i),
                    nn.Sequential(
                        # The 1x1 conv reduces each stripe from embedding_dim to part_dim channels;
                        # the BatchNorm width must match the conv output (part_dim).
                        nn.Conv2d(self.embedding_dim, self.part_dim, (1, 1), bias=False),
                        nn.BatchNorm2d(self.part_dim),
                        nn.ReLU()
                    ))
        self.random_init()
    def forward(self, x):
        x = self.cnn.layer0(x)
        x = self.cnn.layer1(x)
        x = self.cnn.layer2(x)
        x = self.cnn.layer3(x)
        x = self.cnn.layer4(x)

        # Global branch: (B, embedding_dim, H, W) -> (B, embedding_dim)
        x_full = self.cnn.avg_pool(x)
        x_full = x_full.reshape(x_full.shape[0], -1)

        # Part branch: pool into part_num stripes, then reduce each stripe to part_dim.
        x_part = self.avg_pool_part6(x)
        parts = []
        for i in range(self.part_num):
            reduction_i = getattr(self, 'reduction_' + str(i))
            part_i = x_part[:, :, i: i + 1, :]
            # flatten(1) keeps the batch dimension even when batch size is 1,
            # unlike a bare squeeze().
            parts.append(reduction_i(part_i).flatten(1))

        return {
            'full': x_full,
            'parts': parts,
        }
    def random_init(self):
        # The wrapped SE-ResNeXt exposes layer0..layer4 rather than a top-level
        # conv1, so initialization is applied uniformly over all conv/norm modules.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
@BACKBONE_REGISTRY.register()
def build_senet_pcb_backbone(cfg: CfgNode):
    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    part_num      = cfg.MODEL.PCB.PART_NUM
    part_dim      = cfg.MODEL.PCB.PART_DIM
    embedding_dim = cfg.MODEL.PCB.EMBEDDING_DIM
    # fmt: on

    model = SePcbNet(part_num=part_num, embedding_dim=embedding_dim, part_dim=part_dim, last_stride=last_stride)

    if pretrain:
        if pretrain_path:
            try:
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
                # Strip the first two prefix components from checkpoint keys and
                # keep only the entries whose shapes match the model.
                new_state_dict = {}
                for k in state_dict:
                    new_k = '.'.join(k.split('.')[2:])
                    if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
                        new_state_dict[new_k] = state_dict[k]
                state_dict = new_state_dict
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.error(f'{pretrain_path} is not found! Please check this path.')
                raise e
            except KeyError as e:
                logger.error("State dict keys error! Please check the state dict.")
                raise e

            incompatible = model.load_state_dict(state_dict, strict=False)
            if incompatible.missing_keys:
                logger.info(
                    get_missing_parameters_message(incompatible.missing_keys)
                )
            if incompatible.unexpected_keys:
                logger.info(
                    get_unexpected_parameters_message(incompatible.unexpected_keys)
                )
        else:
            logger.info('No pretrain path configured for SePcbNet; weights are randomly initialized')

    return model
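
A quick shape check for the new backbone (a sketch, not part of the commit; it assumes the module is importable as fastreid.modeling.backbones.se_pcb_net, which the __init__ import above suggests, and that the pretrainedmodels ImageNet weights for se_resnext101_32x4d are available locally). part_num=3 matches the 2048 + 3x512 feature split used by PcbOnline below:

import torch
from fastreid.modeling.backbones.se_pcb_net import SePcbNet

model = SePcbNet(part_num=3, embedding_dim=2048, part_dim=512, last_stride=(1, 1)).eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 384, 128))
print(out['full'].shape)                       # torch.Size([2, 2048])
print([tuple(p.shape) for p in out['parts']])  # [(2, 512), (2, 512), (2, 512)]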

View File

@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
import logging
import torch
import torch.nn.functional as F
from torch import nn

from fastreid.config import CfgNode
from fastreid.config import configurable
from fastreid.layers import weights_init_classifier
from fastreid.layers import any_softmax
from fastreid.modeling.heads import REID_HEADS_REGISTRY
logger = logging.getLogger(__name__)
@@ -45,10 +45,6 @@ class PcbHead(nn.Module):
        self.full_dim = full_dim
        self.part_dim = part_dim
        self.embedding_dim = embedding_dim
        self.match_full = nn.Sequential(
            nn.Dropout(p=0.5),

View File

@@ -0,0 +1,70 @@
# coding: utf-8
"""
Sun, Y., Zheng, L., Yang, Y., Tian, Q., & Wang, S. (2017). Beyond part models: person retrieval with refined part pooling (and a strong convolutional baseline). Springer, Cham.

Implements the PCB model exactly as it runs online (in production).
"""
import torch

from fastreid.modeling.losses import cross_entropy_loss, log_accuracy
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY


@META_ARCH_REGISTRY.register()
class PcbOnline(Baseline):

    def forward(self, batched_inputs):
        # TODO: normalize inputs (subtract mean, divide by std)
        q = batched_inputs['q_img']
        x = batched_inputs['p_img']

        qf = self.cnn(q)
        xf = self.cnn(x)

        # L2 norm  TODO: inspect the actual values to confirm which norm this really is
        qf = self._norm(qf)
        xf = self._norm(xf)

        # Features are laid out as [full (2048), part_0 (512), part_1 (512), part_2 (512)].
        qf_full, qf_part_0, qf_part_1, qf_part_2 = torch.split(qf, [2048, 512, 512, 512], dim=-1)
        xf_full, xf_part_0, xf_part_1, xf_part_2 = torch.split(xf, [2048, 512, 512, 512], dim=-1)

        # Each branch matches query/gallery features as [q, x, |q - x|, q * x].
        m_full = self.dense_matching_full(
            torch.cat([qf_full, xf_full, (qf_full - xf_full).abs(), qf_full * xf_full], dim=-1))
        m_part_0 = self.dense_matching_part_0(
            torch.cat([qf_part_0, xf_part_0, (qf_part_0 - xf_part_0).abs(), qf_part_0 * xf_part_0], dim=-1))
        m_part_1 = self.dense_matching_part_1(
            torch.cat([qf_part_1, xf_part_1, (qf_part_1 - xf_part_1).abs(), qf_part_1 * xf_part_1], dim=-1))
        m_part_2 = self.dense_matching_part_2(
            torch.cat([qf_part_2, xf_part_2, (qf_part_2 - xf_part_2).abs(), qf_part_2 * xf_part_2], dim=-1))

        m_all = self.dense_matching_all(torch.cat([m_full, m_part_0, m_part_1, m_part_2], dim=-1))
        return m_all.squeeze()

    def losses(self, outputs, gt_labels):
        """
        Compute loss from modeling's outputs, the loss function input arguments
        must be the same as the outputs of the model forwarding.
        """
        # model predictions
        pred_class_logits = outputs['pred_class_logits'].detach()
        cls_outputs = outputs['cls_outputs']

        # Log prediction accuracy
        log_accuracy(pred_class_logits, gt_labels)

        loss_dict = {}
        loss_names = self.loss_kwargs['loss_names']

        if 'CrossEntropyLoss' in loss_names:
            ce_kwargs = self.loss_kwargs.get('ce')
            loss_dict['loss_cls'] = cross_entropy_loss(
                cls_outputs,
                gt_labels,
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale')

        return loss_dict
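
The pairwise matching construction above is simple enough to check in isolation. A minimal standalone sketch (the helper name is hypothetical, but the [q, x, |q - x|, q * x] layout is exactly what forward() builds for each branch):

import torch

def matching_features(q: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Concatenate query, gallery, absolute difference, and elementwise product,
    # quadrupling the feature dimension (512 -> 2048 per part, 2048 -> 8192 for full).
    return torch.cat([q, x, (q - x).abs(), q * x], dim=-1)

q, x = torch.randn(4, 512), torch.randn(4, 512)
print(matching_features(q, x).shape)  # torch.Size([4, 2048])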