mirror of https://github.com/JDAI-CV/fast-reid.git
Implement a PCBNet identical to the online (production) model
parent 5a075c1fe8
commit 65cfc515d9
@@ -4,7 +4,6 @@
 @contact: sherlockliao01@gmail.com
 """
 
-from se_pcb_net import build_senet_pcb_backbone
 from .build import build_backbone, BACKBONE_REGISTRY
 from .mobilenet import build_mobilenetv2_backbone
 from .osnet import build_osnet_backbone
@@ -15,3 +14,4 @@ from .resnet import build_resnet_backbone
 from .resnext import build_resnext_backbone
 from .shufflenet import build_shufflenetv2_backbone
 from .vision_transformer import build_vit_backbone
+from .se_pcb_net import build_senet_pcb_backbone
@@ -22,13 +22,13 @@ class SePcbNet(nn.Module):
         part_num: int,
         embedding_dim: int,
         part_dim: int,
-        last_stride: Tuple[int, int]
+        last_stride: int,
     ):
         super(SePcbNet, self).__init__()
         self.part_num = part_num
         self.embedding_dim = embedding_dim
         self.part_dim = part_dim
-        self.last_stride = last_stride
+        self.last_stride = (last_stride, last_stride)
 
         self.cnn = pretrainedmodels.__dict__["se_resnext101_32x4d"](pretrained='imagenet')
         self.cnn.layer4[0].downsample[0].stride = self.last_stride
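Note: the constructor now takes a plain int and builds the stride tuple itself, then patches the first block of `layer4` in the pretrained SE-ResNeXt-101 so that, with `LAST_STRIDE: 1`, the final feature map keeps twice the spatial resolution, which matters when it is later sliced into horizontal parts. A minimal sketch of the patch, assuming the Cadene `pretrainedmodels` package; the extra `conv2` line is an assumption, since the diff only shows the downsample conv being patched:

```python
import pretrainedmodels  # pip install pretrainedmodels

# Load an ImageNet-pretrained SE-ResNeXt-101 backbone, as the diff does.
cnn = pretrainedmodels.__dict__["se_resnext101_32x4d"](pretrained="imagenet")

last_stride = 1  # LAST_STRIDE: 1 from the config
stride = (last_stride, last_stride)

# The diff patches the downsample (shortcut) conv of layer4's first block.
cnn.layer4[0].downsample[0].stride = stride
# NOTE (assumption): to fully change the output stride, the residual branch's
# strided 3x3 conv usually needs the same patch; the diff does not show this.
# cnn.layer4[0].conv2.stride = stride
```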
@@ -40,7 +40,7 @@ class SePcbNet(nn.Module):
             setattr(self, 'reduction_' + str(i),
                     nn.Sequential(
                         nn.Conv2d(self.embedding_dim, self.part_dim, (1, 1), bias=False),
-                        nn.BatchNorm2d(self.part_num),
+                        nn.BatchNorm2d(self.part_dim),
                         nn.ReLU()
                     ))
 
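This one-liner is a real bug fix: `nn.BatchNorm2d` takes the number of channels it normalizes, which must equal the preceding conv's `out_channels` (`part_dim`), not the number of parts. A minimal sketch of one per-part reduction head, with illustrative dims (SE-ResNeXt-101 ends at 2048 channels; `PART_DIM` is 512 in the config):

```python
import torch
import torch.nn as nn

embedding_dim, part_dim = 2048, 512  # illustrative dims, see note above

# One per-part reduction head: 1x1 conv from backbone channels down to part_dim.
# BatchNorm2d must be sized to the conv's out_channels (part_dim) -- the bug fixed here.
reduction = nn.Sequential(
    nn.Conv2d(embedding_dim, part_dim, (1, 1), bias=False),
    nn.BatchNorm2d(part_dim),
    nn.ReLU(),
)

x = torch.randn(2, embedding_dim, 1, 1)  # one pooled part per image
print(reduction(x).shape)  # torch.Size([2, 512, 1, 1])
```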
@@ -70,7 +70,6 @@ class SePcbNet(nn.Module):
         }
 
     def random_init(self):
-        self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -82,6 +81,8 @@ class SePcbNet(nn.Module):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
 
+        self.cnn.layer0.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
+
 
 @BACKBONE_REGISTRY.register()
 def build_senet_pcb_backbone(cfg: CfgNode):
@@ -99,10 +100,10 @@ def build_senet_pcb_backbone(cfg: CfgNode):
     if pretrain:
         if pretrain_path:
             try:
-                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
+                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
                 new_state_dict = {}
                 for k in state_dict:
-                    new_k = '.'.join(k.split('.')[2:])
+                    new_k = 'cnn.' + k
                     if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
                         new_state_dict[new_k] = state_dict[k]
                 state_dict = new_state_dict
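The loader change reflects the new checkpoint format: the raw `se_resnext101_32x4d` ImageNet checkpoint is a flat state dict with no `'model'` wrapper, and because the backbone lives under the `cnn.` attribute of `SePcbNet`, each key simply gains a `cnn.` prefix instead of having its first two components stripped; keys with no name or shape match are silently dropped. A self-contained sketch of that remap-and-filter logic (the `Wrapper` class is a hypothetical stand-in for `SePcbNet`):

```python
import torch
import torch.nn as nn

def remap_pretrained(model: nn.Module, ckpt: dict) -> dict:
    """Prefix checkpoint keys with 'cnn.' and keep only name+shape matches,
    mirroring the filtering loop in build_senet_pcb_backbone."""
    model_sd = model.state_dict()
    new_state_dict = {}
    for k, v in ckpt.items():
        new_k = 'cnn.' + k
        if new_k in model_sd and model_sd[new_k].shape == v.shape:
            new_state_dict[new_k] = v
    return new_state_dict

# Hypothetical wrapper mimicking SePcbNet's `self.cnn` attribute.
class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Linear(4, 4)

model = Wrapper()
ckpt = {'weight': torch.randn(4, 4), 'bias': torch.randn(4), 'fc.weight': torch.randn(8, 8)}
sd = remap_pretrained(model, ckpt)
model.load_state_dict(sd, strict=False)  # loads cnn.weight / cnn.bias, drops the rest
print(sorted(sd))  # ['cnn.bias', 'cnn.weight']
```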
@@ -79,37 +79,31 @@ class PcbHead(nn.Module):
         self.reset_parameters()
 
     def forward(self, features, targets=None):
-        full = features['full']
-        parts = features['parts']
-        bsz = full.size(0)
-
-        # normalize
-        full = self._normalize(full)
-        parts = self._normalize(parts)
+        query_feature = features['query']
+        gallery_feature = features['gallery']
 
         # split features into pair
-        query_full, gallery_full = self._split_features(full, bsz)
-        query_part_0, gallery_part_0 = self._split_features(parts[0], bsz)
-        query_part_1, gallery_part_1 = self._split_features(parts[1], bsz)
-        query_part_2, gallery_part_2 = self._split_features(parts[2], bsz)
+        query_full, query_part_0, query_part_1, query_part_2 = torch.split(query_feature,
+            [self.full_dim, self.part_dim, self.part_dim, self.part_dim], dim=-1)
+        gallery_full, gallery_part_0, gallery_part_1, gallery_part_2 = torch.split(gallery_feature,
+            [self.full_dim, self.part_dim, self.part_dim, self.part_dim], dim=-1)
 
         m_full = self.match_full(
-            torch.cat([query_full, gallery_full, query_full - gallery_full,
+            torch.cat([query_full, gallery_full, (query_full - gallery_full).abs(),
                        query_full * gallery_full], dim=-1)
         )
 
         m_part_0 = self.match_part_0(
-            torch.cat([query_part_0, gallery_part_0, query_part_0 - gallery_part_0,
+            torch.cat([query_part_0, gallery_part_0, (query_part_0 - gallery_part_0).abs(),
                        query_part_0 * gallery_part_0], dim=-1)
         )
 
         m_part_1 = self.match_part_1(
-            torch.cat([query_part_1, gallery_part_1, query_part_1 - gallery_part_1,
+            torch.cat([query_part_1, gallery_part_1, (query_part_1 - gallery_part_1).abs(),
                        query_part_1 * gallery_part_1], dim=-1)
        )
 
         m_part_2 = self.match_part_2(
-            torch.cat([query_part_2, gallery_part_2, query_part_2 - gallery_part_2,
+            torch.cat([query_part_2, gallery_part_2, (query_part_2 - gallery_part_2).abs(),
                        query_part_2 * gallery_part_2], dim=-1)
         )
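The head now receives already-normalized query/gallery vectors, recovers the full and part embeddings with `torch.split`, and scores each pair through a matching MLP fed the concatenation [q, g, |q − g|, q ⊙ g]; switching the plain difference to an absolute difference makes the matcher symmetric in query/gallery order. A minimal sketch with assumed dims and a hypothetical stand-in for `self.match_full`:

```python
import torch
import torch.nn as nn

full_dim, part_dim = 512, 512  # assumed from EMBEDDING_DIM / PART_DIM in the config

# Hypothetical stand-in for self.match_full: scores a pair from 4 concatenated views.
match_full = nn.Sequential(nn.Linear(4 * full_dim, 256), nn.ReLU(), nn.Linear(256, 2))

bsz = 8
query = torch.randn(bsz, full_dim + 3 * part_dim)
gallery = torch.randn(bsz, full_dim + 3 * part_dim)

# Recover full + 3 part embeddings, as the new forward() does.
q_full, q_p0, q_p1, q_p2 = torch.split(query, [full_dim, part_dim, part_dim, part_dim], dim=-1)
g_full, g_p0, g_p1, g_p2 = torch.split(gallery, [full_dim, part_dim, part_dim, part_dim], dim=-1)

# Pairwise matching features: [q, g, |q - g|, q * g]; abs() makes the
# difference term invariant to swapping query and gallery.
pair = torch.cat([q_full, g_full, (q_full - g_full).abs(), q_full * g_full], dim=-1)
logits = match_full(pair)
print(logits.shape)  # torch.Size([8, 2])
```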
@@ -14,3 +14,4 @@ from .moco import MoCo
 from .distiller import Distiller
 from .metric import Metric
 from .pcb import PCB
+from .pcb_online import PcbOnline
@@ -0,0 +1,60 @@
+# coding: utf-8
+"""
+Sun, Y., Zheng, L., Yang, Y., Tian, Q., & Wang, S. (2017). Beyond part models: person retrieval with refined part pooling (and a strong convolutional baseline). Springer, Cham.
+Implements PCB exactly matching the online (production) model.
+"""
+import torch
+import torch.nn.functional as F
+
+from fastreid.modeling.losses import cross_entropy_loss, log_accuracy
+from fastreid.modeling.meta_arch import Baseline
+from fastreid.modeling.meta_arch import META_ARCH_REGISTRY
+
+
+@META_ARCH_REGISTRY.register()
+class PcbOnline(Baseline):
+
+    def forward(self, batched_inputs):
+        images = self.preprocess_image(batched_inputs)
+        bsz = int(images.size(0) / 2)
+        feats = self.backbone(images)
+        feats = torch.cat((feats['full'], feats['parts'][0], feats['parts'][1], feats['parts'][2]), 1)
+        feats = F.normalize(feats, p=2.0, dim=-1)
+
+        qf = feats[0: bsz * 2: 2, ...]
+        xf = feats[1: bsz * 2: 2, ...]
+        outputs = self.heads({'query': qf, 'gallery': xf})
+
+        if self.training:
+            targets = batched_inputs['targets']
+            losses = self.losses(outputs, targets)
+            return losses
+        else:
+            return outputs
+
+    def losses(self, outputs, gt_labels):
+        """
+        Compute loss from the model's outputs; the loss function's input
+        arguments must match the outputs of the model's forward pass.
+        """
+        # model predictions
+        pred_class_logits = outputs['pred_class_logits'].detach()
+        cls_outputs = outputs['cls_outputs']
+
+        # Log prediction accuracy
+        log_accuracy(pred_class_logits, gt_labels)
+
+        loss_dict = {}
+        loss_names = self.loss_kwargs['loss_names']
+
+        if 'CrossEntropyLoss' in loss_names:
+            ce_kwargs = self.loss_kwargs.get('ce')
+            loss_dict['loss_cls'] = cross_entropy_loss(
+                cls_outputs,
+                gt_labels,
+                ce_kwargs.get('eps'),
+                ce_kwargs.get('alpha')
+            ) * ce_kwargs.get('scale')
+
+        return loss_dict
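`PcbOnline.forward` assumes the batch interleaves query and gallery images (q0, g0, q1, g1, ...): after concatenating the full and three part features and L2-normalizing, the even rows are query features and the odd rows gallery features. A minimal sketch of that slicing, with random stand-ins for the backbone outputs and assumed dims:

```python
import torch
import torch.nn.functional as F

bsz, full_dim, part_dim = 4, 512, 512  # assumed dims
# Stand-ins for the backbone's outputs on an interleaved (2*bsz) batch.
feats = {
    'full': torch.randn(2 * bsz, full_dim),
    'parts': [torch.randn(2 * bsz, part_dim) for _ in range(3)],
}

x = torch.cat((feats['full'], feats['parts'][0], feats['parts'][1], feats['parts'][2]), 1)
x = F.normalize(x, p=2.0, dim=-1)  # joint L2 norm over the concatenated vector

qf = x[0: bsz * 2: 2, ...]  # rows 0, 2, 4, ... -> queries
xf = x[1: bsz * 2: 2, ...]  # rows 1, 3, 5, ... -> galleries
assert qf.shape == xf.shape == (bsz, full_dim + 3 * part_dim)
```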
@@ -1,8 +1,8 @@
 _BASE_: base.yaml
 
 MODEL:
-  META_ARCHITECTURE: PCB
+  META_ARCHITECTURE: PcbOnline
 
   PCB:
     PART_NUM: 3
     PART_DIM: 512
@@ -14,10 +14,12 @@ MODEL:
   EMBEDDING_DIM: 512
 
   BACKBONE:
-    NAME: build_resnet_backbone
+    PRETRAIN: True
+    PRETRAIN_PATH: /home/apps/.cache/torch/hub/checkpoints/se_resnext101_32x4d-3b2fe3d8.pth
+    NAME: build_senet_pcb_backbone
     DEPTH: 101x
     NORM: BN
-    LAST_STRIDE: 2
+    LAST_STRIDE: 1
     FEAT_DIM: 512
     PRETRAIN: True
     WITH_IBN: True
@@ -46,11 +48,34 @@ INPUT:
     ENABLED: True
     SIZE: [270, 260]
     SCALE: [0.8, 1.2]
-    RATIO: [3./4, 4./3]
+    RATIO: [0.75, 1.33333333]
 
 DATALOADER:
   NUM_WORKERS: 8
 
+SOLVER:
+  OPT: SGD
+  SCHED: CosineAnnealingLR
+
+  BASE_LR: 0.001
+  MOMENTUM: 0.9
+  NESTEROV: False
+
+  BIAS_LR_FACTOR: 1.
+  WEIGHT_DECAY: 0.0005
+  WEIGHT_DECAY_BIAS: 0.
+  ETA_MIN_LR: 0.00003
+
+  WARMUP_FACTOR: 0.1
+  WARMUP_ITERS: 1000
+
+  IMS_PER_BATCH: 40
+
+TEST:
+  IMS_PER_BATCH: 64
+
 DATASETS:
   NAMES: ("ShoeDataset",)
-  TESTS: ("ShoeDataset", "OnlineDataset")
+  TESTS: ("ShoeDataset",)
+
+OUTPUT_DIR: projects/FastShoe/logs/online-pcb
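The `RATIO` change is more than cosmetic: YAML has no arithmetic, so `3./4` parses as the string "3./4" rather than 0.75, which breaks any consumer expecting an aspect-ratio float. Writing the literals out keeps the values numeric. A quick check, assuming PyYAML is available:

```python
import yaml  # pip install pyyaml

old = yaml.safe_load("RATIO: [3./4, 4./3]")
new = yaml.safe_load("RATIO: [0.75, 1.33333333]")
print(old)  # {'RATIO': ['3./4', '4./3']} -- strings, not numbers
print(new)  # {'RATIO': [0.75, 1.33333333]} -- proper floats
```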
@@ -28,14 +28,9 @@ class PairDataset(Dataset):
         self._logger.info('set {} with {} random seed: 12345'.format(self.mode, self.__class__.__name__))
         seed_all_rng(12345)
 
-        # if self.mode == 'train':
-        #     # make negative sample come from all negative folders when train
-        #     self.neg_folders = sum(self.neg_folders, list())
-
     def __len__(self):
         if self.mode == 'test':
             return len(self.pos_folders) * 10
-
         return len(self.pos_folders)
 
     def __getitem__(self, idx):
@@ -43,9 +38,6 @@ class PairDataset(Dataset):
             idx = int(idx / 10)
 
         pf = self.pos_folders[idx]
-        # if self.mode == 'train':
-        #     nf = self.neg_folders
-        # else:
         nf = self.neg_folders[idx]
 
         label = 1
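These two hunks just delete dead commented-out code; the surviving logic expands each positive folder into 10 test samples: `__len__` returns `10 * len(pos_folders)` and `__getitem__` maps the expanded index back with `int(idx / 10)`. A standalone sketch of only that indexing scheme; the folder contents and the `label` choice are hypothetical, since the diff cuts off before the pairing logic:

```python
import random

class PairIndexSketch:
    """Illustrates only the index expansion of PairDataset, not its file I/O."""
    def __init__(self, pos_folders, neg_folders, mode='test'):
        self.pos_folders, self.neg_folders, self.mode = pos_folders, neg_folders, mode

    def __len__(self):
        if self.mode == 'test':
            return len(self.pos_folders) * 10  # 10 evaluation pairs per folder
        return len(self.pos_folders)

    def __getitem__(self, idx):
        if self.mode == 'test':
            idx = int(idx / 10)  # map expanded index back to the folder index
        pf = self.pos_folders[idx]
        nf = self.neg_folders[idx]
        label = 1 if random.random() < 0.5 else 0  # assumption: positive/negative pair choice
        return pf, nf, label

ds = PairIndexSketch(['f0', 'f1'], [['n0'], ['n1']])
print(len(ds))    # 20
print(ds[13][0])  # 'f1' (13 // 10 == 1)
```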