mirror of https://github.com/JDAI-CV/fast-reid.git
Implement a PCBNet identical to the online model
parent 5a075c1fe8
commit 65cfc515d9
@@ -4,7 +4,6 @@
 @contact: sherlockliao01@gmail.com
 """
 
-from se_pcb_net import build_senet_pcb_backbone
 from .build import build_backbone, BACKBONE_REGISTRY
 from .mobilenet import build_mobilenetv2_backbone
 from .osnet import build_osnet_backbone
@@ -15,3 +14,4 @@ from .resnet import build_resnet_backbone
 from .resnext import build_resnext_backbone
 from .shufflenet import build_shufflenetv2_backbone
 from .vision_transformer import build_vit_backbone
+from .se_pcb_net import build_senet_pcb_backbone
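The two hunks above move the SE-PCB backbone import from a broken absolute form to a relative import at the end of the file, so the module is actually loaded and its @BACKBONE_REGISTRY.register() decorator runs. Below is a minimal, self-contained sketch of the registry pattern this relies on; the Registry class here is a simplified stand-in, not fastreid's real implementation.

# Minimal sketch of the registry pattern (simplified stand-in, not fastreid's Registry).
class Registry:
    def __init__(self, name):
        self.name = name
        self._obj_map = {}

    def register(self):
        def deco(fn):
            self._obj_map[fn.__name__] = fn
            return fn
        return deco

    def get(self, name):
        return self._obj_map[name]


BACKBONE_REGISTRY = Registry("BACKBONE")


@BACKBONE_REGISTRY.register()
def build_senet_pcb_backbone(cfg):
    # The real builder constructs SePcbNet from cfg.MODEL.BACKBONE options.
    return "SePcbNet(...)"


# The trainer looks the builder up by the string in cfg.MODEL.BACKBONE.NAME,
# which is why the module must be imported in __init__.py for the decorator to run.
print(BACKBONE_REGISTRY.get("build_senet_pcb_backbone")({}))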
@@ -22,13 +22,13 @@ class SePcbNet(nn.Module):
             part_num: int,
             embedding_dim: int,
             part_dim: int,
-            last_stride: Tuple[int, int]
+            last_stride: int,
     ):
         super(SePcbNet, self).__init__()
         self.part_num = part_num
         self.embedding_dim = embedding_dim
         self.part_dim = part_dim
-        self.last_stride = last_stride
+        self.last_stride = (last_stride, last_stride)
 
         self.cnn = pretrainedmodels.__dict__["se_resnext101_32x4d"](pretrained='imagenet')
         self.cnn.layer4[0].downsample[0].stride = self.last_stride
@@ -40,7 +40,7 @@ class SePcbNet(nn.Module):
             setattr(self, 'reduction_' + str(i),
                     nn.Sequential(
                         nn.Conv2d(self.embedding_dim, self.part_dim, (1, 1), bias=False),
-                        nn.BatchNorm2d(self.part_num),
+                        nn.BatchNorm2d(self.part_dim),
                         nn.ReLU()
                     ))
 
@@ -70,7 +70,6 @@ class SePcbNet(nn.Module):
         }
 
     def random_init(self):
-        self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -82,6 +81,8 @@ class SePcbNet(nn.Module):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
 
+        self.cnn.layer0.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
+
 
 @BACKBONE_REGISTRY.register()
 def build_senet_pcb_backbone(cfg: CfgNode):
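In the reduction branches above, the 1x1 convolution maps embedding_dim channels to part_dim channels, so the BatchNorm2d that follows has to be built with part_dim; the old part_num value raises a shape error whenever part_dim differs from part_num. A minimal sketch of one such branch, using the dimensions from the config further down (EMBEDDING_DIM 512, PART_DIM 512, PART_NUM 3) as assumed values:

# Minimal sketch of one per-part reduction branch (assumed dimensions for illustration).
import torch
import torch.nn as nn

embedding_dim, part_dim, part_num = 512, 512, 3

reductions = nn.ModuleList([
    nn.Sequential(
        nn.Conv2d(embedding_dim, part_dim, (1, 1), bias=False),
        nn.BatchNorm2d(part_dim),   # must match the conv's out_channels, not part_num
        nn.ReLU(),
    )
    for _ in range(part_num)
])

part_feat = torch.randn(4, embedding_dim, 1, 1)   # one pooled part, batch of 4
out = reductions[0](part_feat)
print(out.shape)                                  # torch.Size([4, 512, 1, 1])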
@@ -99,10 +100,10 @@ def build_senet_pcb_backbone(cfg: CfgNode):
     if pretrain:
         if pretrain_path:
             try:
-                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
+                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
                 new_state_dict = {}
                 for k in state_dict:
-                    new_k = '.'.join(k.split('.')[2:])
+                    new_k = 'cnn.' + k
                     if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
                         new_state_dict[new_k] = state_dict[k]
                 state_dict = new_state_dict
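The loader above now treats the checkpoint as a plain ImageNet state dict (no 'model' wrapper) and prefixes every key with 'cnn.' so it lines up with the attribute that wraps the se_resnext101_32x4d trunk inside SePcbNet. A minimal sketch of that key-remapping filter, with hypothetical tensors standing in for a real checkpoint and model state_dict:

# Minimal sketch of the key-remapping filter (hypothetical tensors).
import torch

checkpoint = {                      # keys as saved by pretrainedmodels
    'layer0.conv1.weight': torch.zeros(64, 3, 7, 7),
    'last_linear.weight': torch.zeros(1000, 2048),
}
model_state = {                     # keys as they appear inside SePcbNet
    'cnn.layer0.conv1.weight': torch.zeros(64, 3, 7, 7),
}

new_state_dict = {}
for k, v in checkpoint.items():
    new_k = 'cnn.' + k              # the backbone lives under the `cnn` attribute
    if new_k in model_state and model_state[new_k].shape == v.shape:
        new_state_dict[new_k] = v   # classifier weights are silently dropped

print(sorted(new_state_dict))       # ['cnn.layer0.conv1.weight']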
@@ -79,37 +79,31 @@ class PcbHead(nn.Module):
         self.reset_parameters()
 
     def forward(self, features, targets=None):
-        full = features['full']
-        parts = features['parts']
-        bsz = full.size(0)
+        query_feature = features['query']
+        gallery_feature = features['gallery']
 
-        # normalize
-        full = self._normalize(full)
-        parts = self._normalize(parts)
-
-        # split features into pair
-        query_full, gallery_full = self._split_features(full, bsz)
-        query_part_0, gallery_part_0 = self._split_features(parts[0], bsz)
-        query_part_1, gallery_part_1 = self._split_features(parts[1], bsz)
-        query_part_2, gallery_part_2 = self._split_features(parts[2], bsz)
+        query_full, query_part_0, query_part_1, query_part_2 = torch.split(query_feature,
+            [self.full_dim, self.part_dim, self.part_dim, self.part_dim], dim=-1)
+        gallery_full, gallery_part_0, gallery_part_1, gallery_part_2 = torch.split(gallery_feature,
+            [self.full_dim, self.part_dim, self.part_dim, self.part_dim], dim=-1)
 
         m_full = self.match_full(
-            torch.cat([query_full, gallery_full, query_full - gallery_full,
+            torch.cat([query_full, gallery_full, (query_full - gallery_full).abs(),
                        query_full * gallery_full], dim=-1)
         )
 
         m_part_0 = self.match_part_0(
-            torch.cat([query_part_0, gallery_part_0, query_part_0 - gallery_part_0,
+            torch.cat([query_part_0, gallery_part_0, (query_part_0 - gallery_part_0).abs(),
                        query_part_0 * gallery_part_0], dim=-1)
         )
 
         m_part_1 = self.match_part_1(
-            torch.cat([query_part_1, gallery_part_1, query_part_1 - gallery_part_1,
+            torch.cat([query_part_1, gallery_part_1, (query_part_1 - gallery_part_1).abs(),
                        query_part_1 * gallery_part_1], dim=-1)
         )
 
         m_part_2 = self.match_part_2(
-            torch.cat([query_part_2, gallery_part_2, query_part_2 - gallery_part_2,
+            torch.cat([query_part_2, gallery_part_2, (query_part_2 - gallery_part_2).abs(),
                        query_part_2 * gallery_part_2], dim=-1)
         )
 
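PcbHead now receives separate 'query' and 'gallery' vectors, splits each into one full-image slice plus three part slices, and feeds every pair through a matcher on the concatenation [q, g, |q - g|, q * g]; the absolute difference makes the matching feature symmetric in the sign of the gap. A minimal sketch of one matching branch, with hypothetical dimensions and a plain Linear layer standing in for match_full:

# Minimal sketch of the pairwise matching feature (hypothetical dims/classifier).
import torch
import torch.nn as nn

full_dim, part_dim = 512, 512
q = torch.randn(8, full_dim + 3 * part_dim)   # query features, batch of 8 pairs
g = torch.randn(8, full_dim + 3 * part_dim)   # gallery features

q_full, q_p0, q_p1, q_p2 = torch.split(q, [full_dim, part_dim, part_dim, part_dim], dim=-1)
g_full, g_p0, g_p1, g_p2 = torch.split(g, [full_dim, part_dim, part_dim, part_dim], dim=-1)

match_full = nn.Linear(4 * full_dim, 2)       # stand-in for self.match_full
pair_feat = torch.cat([q_full, g_full, (q_full - g_full).abs(), q_full * g_full], dim=-1)
logits = match_full(pair_feat)
print(logits.shape)                           # torch.Size([8, 2])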
@@ -14,3 +14,4 @@ from .moco import MoCo
 from .distiller import Distiller
 from .metric import Metric
 from .pcb import PCB
+from .pcb_online import PcbOnline
@@ -0,0 +1,60 @@
+# coding: utf-8
+"""
+Sun, Y., Zheng, L., Yang, Y., Tian, Q., & Wang, S. (2017). Beyond part models: person retrieval with refined part pooling (and a strong convolutional baseline). Springer, Cham.
+Implements a PCB identical to the online model.
+"""
+import torch
+import torch.nn.functional as F
+
+from fastreid.modeling.losses import cross_entropy_loss, log_accuracy
+from fastreid.modeling.meta_arch import Baseline
+from fastreid.modeling.meta_arch import META_ARCH_REGISTRY
+
+
+@META_ARCH_REGISTRY.register()
+class PcbOnline(Baseline):
+
+    def forward(self, batched_inputs):
+        images = self.preprocess_image(batched_inputs)
+        bsz = int(images.size(0) / 2)
+        feats = self.backbone(images)
+        feats = torch.cat((feats['full'], feats['parts'][0], feats['parts'][1], feats['parts'][2]), 1)
+        feats = F.normalize(feats, p=2.0, dim=-1)
+
+        qf = feats[0: bsz * 2: 2, ...]
+        xf = feats[1: bsz * 2: 2, ...]
+        outputs = self.heads({'query': qf, 'gallery': xf})
+
+        if self.training:
+            targets = batched_inputs['targets']
+            losses = self.losses(outputs, targets)
+            return losses
+        else:
+            return outputs
+
+    def losses(self, outputs, gt_labels):
+        """
+        Compute loss from modeling's outputs, the loss function input arguments
+        must be the same as the outputs of the model forwarding.
+        """
+        # model predictions
+        pred_class_logits = outputs['pred_class_logits'].detach()
+        cls_outputs = outputs['cls_outputs']
+
+        # Log prediction accuracy
+        log_accuracy(pred_class_logits, gt_labels)
+
+        loss_dict = {}
+        loss_names = self.loss_kwargs['loss_names']
+
+        if 'CrossEntropyLoss' in loss_names:
+            ce_kwargs = self.loss_kwargs.get('ce')
+            loss_dict['loss_cls'] = cross_entropy_loss(
+                cls_outputs,
+                gt_labels,
+                ce_kwargs.get('eps'),
+                ce_kwargs.get('alpha')
+            ) * ce_kwargs.get('scale')
+
+
+        return loss_dict
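PcbOnline.forward assumes the dataloader interleaves each pair as (query, gallery) along the batch dimension, so even indices are queries and odd indices are galleries. A minimal sketch of that slicing on dummy features; the feature size here is an arbitrary assumption:

# Minimal sketch of the query/gallery interleaving assumed by PcbOnline.forward (dummy features).
import torch
import torch.nn.functional as F

bsz, feat_dim = 4, 2048                 # 4 image pairs -> 8 images in the batch
feats = torch.randn(bsz * 2, feat_dim)  # stand-in for the concatenated backbone output
feats = F.normalize(feats, p=2.0, dim=-1)

qf = feats[0: bsz * 2: 2, ...]          # images 0, 2, 4, 6 -> queries
xf = feats[1: bsz * 2: 2, ...]          # images 1, 3, 5, 7 -> galleries
print(qf.shape, xf.shape)               # torch.Size([4, 2048]) torch.Size([4, 2048])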
@@ -1,8 +1,8 @@
 _BASE_: base.yaml
 
 MODEL:
-  META_ARCHITECTURE: PCB
+  META_ARCHITECTURE: PcbOnline
 
   PCB:
     PART_NUM: 3
     PART_DIM: 512
@@ -14,10 +14,12 @@ MODEL:
     EMBEDDING_DIM: 512
 
   BACKBONE:
-    NAME: build_resnet_backbone
+    PRETRAIN: True
+    PRETRAIN_PATH: /home/apps/.cache/torch/hub/checkpoints/se_resnext101_32x4d-3b2fe3d8.pth
+    NAME: build_senet_pcb_backbone
     DEPTH: 101x
     NORM: BN
-    LAST_STRIDE: 2
+    LAST_STRIDE: 1
     FEAT_DIM: 512
     PRETRAIN: True
     WITH_IBN: True
@@ -46,11 +48,34 @@ INPUT:
     ENABLED: True
     SIZE: [270, 260]
    SCALE: [0.8, 1.2]
-    RATIO: [3./4, 4./3]
+    RATIO: [0.75, 1.33333333]
 
+DATALOADER:
+  NUM_WORKERS: 8
+
+SOLVER:
+  OPT: SGD
+  SCHED: CosineAnnealingLR
+
+  BASE_LR: 0.001
+  MOMENTUM: 0.9
+  NESTEROV: False
+
+  BIAS_LR_FACTOR: 1.
+  WEIGHT_DECAY: 0.0005
+  WEIGHT_DECAY_BIAS: 0.
+  ETA_MIN_LR: 0.00003
+
+  WARMUP_FACTOR: 0.1
+  WARMUP_ITERS: 1000
+
+  IMS_PER_BATCH: 40
+
+TEST:
+  IMS_PER_BATCH: 64
+
 DATASETS:
   NAMES: ("ShoeDataset",)
-  TESTS: ("ShoeDataset", "OnlineDataset")
+  TESTS: ("ShoeDataset",)
 
 OUTPUT_DIR: projects/FastShoe/logs/online-pcb
-
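The new SOLVER block combines SGD with a linear warmup and cosine annealing. The sketch below prints the learning rate implied by those values at a few iterations; it is an approximation of warmup-plus-cosine decay under an assumed total iteration count, not fastreid's exact scheduler:

# Minimal sketch of linear warmup followed by cosine annealing using the SOLVER values above.
import math

BASE_LR, ETA_MIN_LR = 0.001, 0.00003
WARMUP_FACTOR, WARMUP_ITERS = 0.1, 1000
TOTAL_ITERS = 20000                     # hypothetical training length

def lr_at(it):
    if it < WARMUP_ITERS:               # linear ramp from 10% of BASE_LR
        alpha = it / WARMUP_ITERS
        return BASE_LR * (WARMUP_FACTOR * (1 - alpha) + alpha)
    # cosine decay from BASE_LR down to ETA_MIN_LR over the remaining iterations
    progress = (it - WARMUP_ITERS) / (TOTAL_ITERS - WARMUP_ITERS)
    return ETA_MIN_LR + 0.5 * (BASE_LR - ETA_MIN_LR) * (1 + math.cos(math.pi * progress))

for it in (0, 500, 1000, 10000, 20000):
    print(it, round(lr_at(it), 6))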
@@ -28,14 +28,9 @@ class PairDataset(Dataset):
         self._logger.info('set {} with {} random seed: 12345'.format(self.mode, self.__class__.__name__))
         seed_all_rng(12345)
-
-        # if self.mode == 'train':
-        #     # make negative sample come from all negative folders when train
-        #     self.neg_folders = sum(self.neg_folders, list())
-
     def __len__(self):
         if self.mode == 'test':
             return len(self.pos_folders) * 10
 
         return len(self.pos_folders)
 
     def __getitem__(self, idx):
@@ -43,9 +38,6 @@ class PairDataset(Dataset):
             idx = int(idx / 10)
 
         pf = self.pos_folders[idx]
-        # if self.mode == 'train':
-        #     nf = self.neg_folders
-        # else:
         nf = self.neg_folders[idx]
 
         label = 1
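With the commented-out branches gone, PairDataset always draws negatives from the folder at the same index as the positive folder, and test mode enumerates ten samples per positive folder under the fixed seed set above. The snippet below is a heavily simplified, hypothetical sketch of that sampling contract; the real dataset also loads image files and applies transforms:

# Hypothetical sketch of the pair-sampling contract (folders and selection logic are made up).
import random

pos_folders = [['a1.jpg', 'a2.jpg'], ['b1.jpg', 'b2.jpg']]
neg_folders = [['x1.jpg'], ['y1.jpg']]
mode = 'test'

def length():
    # ten samples per positive folder when evaluating
    return len(pos_folders) * 10 if mode == 'test' else len(pos_folders)

def get_item(idx):
    if mode == 'test':
        idx = int(idx / 10)
    pf = pos_folders[idx]               # images of the same item
    nf = neg_folders[idx]               # negatives tied to the same index
    label = 1 if random.random() < 0.5 else 0
    anchor = random.choice(pf)
    other = random.choice(pf) if label == 1 else random.choice(nf)
    return anchor, other, label

random.seed(12345)
print(length(), get_item(0))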