实现和线上模型一模一样的PCBNet

pull/608/head
zuchen.wang 2021-11-02 17:56:05 +08:00
parent 5a075c1fe8
commit 65cfc515d9
7 changed files with 111 additions and 38 deletions

View File

@ -4,7 +4,6 @@
@contact: sherlockliao01@gmail.com
"""
from se_pcb_net import build_senet_pcb_backbone
from .build import build_backbone, BACKBONE_REGISTRY
from .mobilenet import build_mobilenetv2_backbone
from .osnet import build_osnet_backbone
@ -15,3 +14,4 @@ from .resnet import build_resnet_backbone
from .resnext import build_resnext_backbone
from .shufflenet import build_shufflenetv2_backbone
from .vision_transformer import build_vit_backbone
from .se_pcb_net import build_senet_pcb_backbone

View File

@ -22,13 +22,13 @@ class SePcbNet(nn.Module):
part_num: int,
embedding_dim: int,
part_dim: int,
last_stride: Tuple[int, int]
last_stride: int,
):
super(SePcbNet, self).__init__()
self.part_num = part_num
self.embedding_dim = embedding_dim
self.part_dim = part_dim
self.last_stride = last_stride
self.last_stride = (last_stride, last_stride)
self.cnn = pretrainedmodels.__dict__["se_resnext101_32x4d"](pretrained='imagenet')
self.cnn.layer4[0].downsample[0].stride = self.last_stride
@ -40,7 +40,7 @@ class SePcbNet(nn.Module):
setattr(self, 'reduction_' + str(i),
nn.Sequential(
nn.Conv2d(self.embedding_dim, self.part_dim, (1, 1), bias=False),
nn.BatchNorm2d(self.part_num),
nn.BatchNorm2d(self.part_dim),
nn.ReLU()
))
@ -70,7 +70,6 @@ class SePcbNet(nn.Module):
}
def random_init(self):
self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@ -82,6 +81,8 @@ class SePcbNet(nn.Module):
m.weight.data.fill_(1)
m.bias.data.zero_()
self.cnn.layer0.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
@BACKBONE_REGISTRY.register()
def build_senet_pcb_backbone(cfg: CfgNode):
@ -99,10 +100,10 @@ def build_senet_pcb_backbone(cfg: CfgNode):
if pretrain:
if pretrain_path:
try:
state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
new_state_dict = {}
for k in state_dict:
new_k = '.'.join(k.split('.')[2:])
new_k = 'cnn.' + k
if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
new_state_dict[new_k] = state_dict[k]
state_dict = new_state_dict

View File

@ -79,37 +79,31 @@ class PcbHead(nn.Module):
self.reset_parameters()
def forward(self, features, targets=None):
full = features['full']
parts = features['parts']
bsz = full.size(0)
query_feature = features['query']
gallery_feature = features['gallery']
# normalize
full = self._normalize(full)
parts = self._normalize(parts)
# split features into pair
query_full, gallery_full = self._split_features(full, bsz)
query_part_0, gallery_part_0 = self._split_features(parts[0], bsz)
query_part_1, gallery_part_1 = self._split_features(parts[1], bsz)
query_part_2, gallery_part_2 = self._split_features(parts[2], bsz)
query_full, query_part_0, query_part_1, query_part_2 = torch.split(query_feature,
[self.full_dim, self.part_dim, self.part_dim, self.part_dim], dim=-1)
gallery_full, gallery_part_0, gallery_part_1, gallery_part_2 = torch.split(gallery_feature,
[self.full_dim, self.part_dim, self.part_dim, self.part_dim], dim=-1)
m_full = self.match_full(
torch.cat([query_full, gallery_full, query_full - gallery_full,
torch.cat([query_full, gallery_full, (query_full - gallery_full).abs(),
query_full * gallery_full], dim=-1)
)
m_part_0 = self.match_part_0(
torch.cat([query_part_0, gallery_part_0, query_part_0 - gallery_part_0,
torch.cat([query_part_0, gallery_part_0, (query_part_0 - gallery_part_0).abs(),
query_part_0 * gallery_part_0], dim=-1)
)
m_part_1 = self.match_part_1(
torch.cat([query_part_1, gallery_part_1, query_part_1 - gallery_part_1,
torch.cat([query_part_1, gallery_part_1, (query_part_1 - gallery_part_1).abs(),
query_part_1 * gallery_part_1], dim=-1)
)
m_part_2 = self.match_part_2(
torch.cat([query_part_2, gallery_part_2, query_part_2 - gallery_part_2,
torch.cat([query_part_2, gallery_part_2, (query_part_2 - gallery_part_2).abs(),
query_part_2 * gallery_part_2], dim=-1)
)

View File

@ -14,3 +14,4 @@ from .moco import MoCo
from .distiller import Distiller
from .metric import Metric
from .pcb import PCB
from .pcb_online import PcbOnline

View File

@ -0,0 +1,60 @@
# coding: utf-8
"""
Sun, Y. , Zheng, L. , Yang, Y. , Tian, Q. , & Wang, S. . (2017). Beyond part models: person retrieval with refined part pooling (and a strong convolutional baseline). Springer, Cham.
实现和线上一模一样的PCB
"""
import torch
import torch.nn.functional as F
from fastreid.modeling.losses import cross_entropy_loss, log_accuracy
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class PcbOnline(Baseline):
def forward(self, batched_inputs):
images = self.preprocess_image(batched_inputs)
bsz = int(images.size(0) / 2)
feats = self.backbone(images)
feats = torch.cat((feats['full'], feats['parts'][0], feats['parts'][1], feats['parts'][2]), 1)
feats = F.normalize(feats, p=2.0, dim=-1)
qf = feats[0: bsz * 2: 2, ...]
xf = feats[1: bsz * 2: 2, ...]
outputs = self.heads({'query': qf, 'gallery': xf})
if self.training:
targets = batched_inputs['targets']
losses = self.losses(outputs, targets)
return losses
else:
return outputs
def losses(self, outputs, gt_labels):
"""
Compute loss from modeling's outputs, the loss function input arguments
must be the same as the outputs of the model forwarding.
"""
# model predictions
pred_class_logits = outputs['pred_class_logits'].detach()
cls_outputs = outputs['cls_outputs']
# Log prediction accuracy
log_accuracy(pred_class_logits, gt_labels)
loss_dict = {}
loss_names = self.loss_kwargs['loss_names']
if 'CrossEntropyLoss' in loss_names:
ce_kwargs = self.loss_kwargs.get('ce')
loss_dict['loss_cls'] = cross_entropy_loss(
cls_outputs,
gt_labels,
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale')
return loss_dict

View File

@ -1,7 +1,7 @@
_BASE_: base.yaml
MODEL:
META_ARCHITECTURE: PCB
META_ARCHITECTURE: PcbOnline
PCB:
PART_NUM: 3
@ -14,10 +14,12 @@ MODEL:
EMBEDDING_DIM: 512
BACKBONE:
NAME: build_resnet_backbone
PRETRAIN: True
PRETRAIN_PATH: /home/apps/.cache/torch/hub/checkpoints/se_resnext101_32x4d-3b2fe3d8.pth
NAME: build_senet_pcb_backbone
DEPTH: 101x
NORM: BN
LAST_STRIDE: 2
LAST_STRIDE: 1
FEAT_DIM: 512
PRETRAIN: True
WITH_IBN: True
@ -46,11 +48,34 @@ INPUT:
ENABLED: True
SIZE: [270, 260]
SCALE: [0.8, 1.2]
RATIO: [3./4, 4./3]
RATIO: [0.75, 1.33333333]
DATALOADER:
NUM_WORKERS: 8
SOLVER:
OPT: SGD
SCHED: CosineAnnealingLR
BASE_LR: 0.001
MOMENTUM: 0.9
NESTEROV: False
BIAS_LR_FACTOR: 1.
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.
ETA_MIN_LR: 0.00003
WARMUP_FACTOR: 0.1
WARMUP_ITERS: 1000
IMS_PER_BATCH: 40
TEST:
IMS_PER_BATCH: 64
DATASETS:
NAMES: ("ShoeDataset",)
TESTS: ("ShoeDataset", "OnlineDataset")
TESTS: ("ShoeDataset",)
OUTPUT_DIR: projects/FastShoe/logs/online-pcb

View File

@ -28,14 +28,9 @@ class PairDataset(Dataset):
self._logger.info('set {} with {} random seed: 12345'.format(self.mode, self.__class__.__name__))
seed_all_rng(12345)
# if self.mode == 'train':
# # make negative sample come from all negative folders when train
# self.neg_folders = sum(self.neg_folders, list())
def __len__(self):
if self.mode == 'test':
return len(self.pos_folders) * 10
return len(self.pos_folders)
def __getitem__(self, idx):
@ -43,9 +38,6 @@ class PairDataset(Dataset):
idx = int(idx / 10)
pf = self.pos_folders[idx]
# if self.mode == 'train':
# nf = self.neg_folders
# else:
nf = self.neg_folders[idx]
label = 1