# fast-reid/fastreid/modeling/meta_arch/mf_network.py (140 lines, 4.7 KiB, Python)

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
import torch.nn.functional as F
from .build import META_ARCH_REGISTRY
from ..model_utils import weights_init_kaiming
from ..backbones import build_backbone
from ..heads import build_reid_heads
from fastreid.modeling.layers import Flatten
@META_ARCH_REGISTRY.register()
class MF_net(nn.Module):
    """Three-branch meta-architecture on top of a shared backbone.

    The backbone built by ``build_backbone(cfg)`` must expose ``conv1``,
    ``bn1``, ``relu``, ``maxpool`` and ``layer1`` .. ``layer4``.  Everything up
    to ``layer3`` forms the shared stem (``self.backbone``); ``layer4`` is kept
    separately as ``self.res4``.

    Each branch sums an adaptive average-pooled and an adaptive max-pooled
    view of a feature map, embeds the sum to 512 dimensions and feeds it to
    its own head:

    * branch 1: ``res4`` output pooled to 1x1, embedding input width 2048
    * branch 2: ``res4`` output pooled to 2x2, embedding input width 8192
    * branch 3: stem (``layer3``) output pooled to 1x1, embedding input
      width 1024

    NOTE(review): the hard-coded widths presumably correspond to a
    ResNet-50-style backbone -- confirm against the configured backbone.
    """

    def __init__(self, cfg):
        """Build the stem, pooling layers, branch embeddings and heads.

        Args:
            cfg: project config object; it is stored on ``self._cfg`` and
                passed to ``build_backbone`` and ``build_reid_heads``.
        """
        super().__init__()
        self._cfg = cfg

        # backbone
        backbone = build_backbone(cfg)
        self.backbone = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
        )

        # body
        self.res4 = backbone.layer4
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool_2 = nn.AdaptiveAvgPool2d((2, 2))
        self.maxpool_2 = nn.AdaptiveMaxPool2d((2, 2))

        # Branches and heads are created in the same interleaved order as
        # before so that parameter ordering and initialisation order do not
        # change.
        # branch 1
        self.branch_1 = self._make_branch(2048)
        self.head1 = build_reid_heads(cfg, 512, nn.Identity())

        # branch 2
        self.branch_2 = self._make_branch(8192)
        self.head2 = build_reid_heads(cfg, 512, nn.Identity())

        # branch 3
        self.branch_3 = self._make_branch(1024)
        self.head3 = build_reid_heads(cfg, 512, nn.Identity())

    @staticmethod
    def _make_branch(in_features, out_features=512):
        """Return a Flatten -> BatchNorm1d -> LeakyReLU -> Linear embedding.

        The returned ``nn.Sequential`` has ``weights_init_kaiming`` applied to
        all of its sub-modules.
        """
        branch = nn.Sequential(
            Flatten(),
            nn.BatchNorm1d(in_features),
            nn.LeakyReLU(0.1, True),
            nn.Linear(in_features, out_features, bias=False),
        )
        branch.apply(weights_init_kaiming)
        return branch

    def _iter_branches(self, images):
        """Lazily yield ``(head, avg_feat, max_feat, embedding)`` per branch.

        This is a generator on purpose: the caller runs a branch's head before
        the next branch is pooled and embedded, which keeps the execution
        order identical to handling the branches one after another.
        """
        mid_feat = self.backbone(images)
        feat = self.res4(mid_feat)

        stages = (
            (feat, self.avgpool, self.maxpool, self.branch_1, self.head1),
            (feat, self.avgpool_2, self.maxpool_2, self.branch_2, self.head2),
            (mid_feat, self.avgpool, self.maxpool, self.branch_3, self.head3),
        )
        for source, avgpool, maxpool, branch, head in stages:
            avg_feat = avgpool(source)
            max_feat = maxpool(source)
            yield head, avg_feat, max_feat, branch(avg_feat + max_feat)

    def forward(self, inputs):
        """Run the network on a batch dict.

        Args:
            inputs: mapping with keys ``"images"`` and ``"targets"``; in eval
                mode ``"camid"`` is read as well.

        Returns:
            In eval mode: ``(pred_feat, targets, inputs["camid"])`` where
            ``pred_feat`` comes from :meth:`inference`.

            In training mode a 10-tuple laid out as
            ``(logits_1, logits_2, logits_3,
            avg_1, avg_2, avg_3, max_1, max_2, max_3, targets)`` where the
            ``avg_*`` / ``max_*`` entries are the flattened pooled features of
            each branch.  :meth:`losses` relies on this layout.
        """
        images = inputs["images"]
        targets = inputs["targets"]

        if not self.training:
            pred_feat = self.inference(images)
            return pred_feat, targets, inputs["camid"]

        # One Flatten instance is enough (assumes Flatten is stateless).
        flatten = Flatten()
        logits, avg_feats, max_feats = [], [], []
        for head, avg_feat, max_feat, embedding in self._iter_branches(images):
            # The head returns a pair; its second element is not used here
            # because the losses are computed on the pooled features instead.
            branch_logits, _ = head(embedding, targets)
            logits.append(branch_logits)
            avg_feats.append(flatten(avg_feat))
            max_feats.append(flatten(max_feat))

        return (*logits, *avg_feats, *max_feats, targets)

    def inference(self, images):
        """Return the normalized concatenation of the three head outputs.

        Must only be called in eval mode.  Each head is called with the branch
        embedding alone; the results are concatenated along ``dim=1`` and
        passed through ``F.normalize``.
        """
        assert not self.training

        pred_feats = [
            head(embedding)
            for head, _, _, embedding in self._iter_branches(images)
        ]
        pred_feat = torch.cat(pred_feats, dim=1)
        return F.normalize(pred_feat)

    def losses(self, outputs):
        """Collect the per-branch losses from the training-mode outputs.

        Args:
            outputs: the 10-tuple returned by :meth:`forward` in training
                mode.

        Returns:
            A dict merging, in this order, the results of each head's
            ``losses`` call for the average-pooled features with the branch
            logits (prefixes ``b1_``..``b3_``) and for the max-pooled features
            with ``None`` in place of logits (prefixes ``mp1_``..``mp3_``).
        """
        heads = (self.head1, self.head2, self.head3)
        targets = outputs[-1]

        loss_dict = {}
        # Average-pooled features (outputs[3:6]) paired with logits (outputs[0:3]).
        for idx, head in enumerate(heads):
            prefix = 'b{}_'.format(idx + 1)
            loss_dict.update(
                head.losses(self._cfg, outputs[idx], outputs[3 + idx], targets, prefix)
            )
        # Max-pooled features (outputs[6:9]); no logits are passed for these.
        for idx, head in enumerate(heads):
            prefix = 'mp{}_'.format(idx + 1)
            loss_dict.update(
                head.losses(self._cfg, None, outputs[6 + idx], targets, prefix)
            )
        return loss_dict