2020-02-10 07:38:56 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
"""
|
|
|
|
@author: liaoxingyu
|
|
|
|
@contact: sherlockliao01@gmail.com
|
|
|
|
"""
|
|
|
|
|
2020-05-25 23:39:11 +08:00
|
|
|
import torch
|
2020-04-19 12:54:01 +08:00
|
|
|
from torch import nn
|
2020-02-10 07:38:56 +08:00
|
|
|
|
2020-06-12 16:34:03 +08:00
|
|
|
from fastreid.layers import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d, FastGlobalAvgPool2d
|
2020-05-01 09:02:46 +08:00
|
|
|
from fastreid.modeling.backbones import build_backbone
|
|
|
|
from fastreid.modeling.heads import build_reid_heads
|
2020-07-06 16:57:43 +08:00
|
|
|
from fastreid.modeling.losses import *
|
2020-02-10 07:38:56 +08:00
|
|
|
from .build import META_ARCH_REGISTRY
|
|
|
|
|
|
|
|
|
|
|
|
@META_ARCH_REGISTRY.register()
class Baseline(nn.Module):
    """Re-ID meta-architecture: backbone -> pooling layer -> re-id head.

    The backbone, pooling layer, head and training losses are all chosen from
    ``cfg``; see ``__init__`` and ``losses`` for the config keys read here.
    """

    def __init__(self, cfg):
        """Build the model from ``cfg``.

        Args:
            cfg: config node. Reads ``MODEL.PIXEL_MEAN``, ``MODEL.PIXEL_STD``,
                ``MODEL.HEADS.POOL_LAYER``, ``MODEL.HEADS.IN_FEAT`` and
                ``MODEL.HEADS.NUM_CLASSES`` directly, and is also passed to
                ``build_backbone`` / ``build_reid_heads``. It is kept on
                ``self._cfg`` for use by ``losses``.

        Raises:
            KeyError: if ``MODEL.HEADS.POOL_LAYER`` names an unsupported pool type.
        """
        super().__init__()
        self._cfg = cfg

        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        # Buffers (not parameters) so they move with the module on .to()/.cuda()
        # without being trained. The (1, C, 1, 1) view broadcasts them over
        # dim 1 of a 4-D image batch.
        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # backbone
        self.backbone = build_backbone(cfg)

        # head
        pool_layer = self._build_pool_layer(cfg.MODEL.HEADS.POOL_LAYER)
        in_feat = cfg.MODEL.HEADS.IN_FEAT
        num_classes = cfg.MODEL.HEADS.NUM_CLASSES
        self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer)

    @staticmethod
    def _build_pool_layer(pool_type):
        """Return a new pooling module for a ``MODEL.HEADS.POOL_LAYER`` name.

        Raises:
            KeyError: if ``pool_type`` is not one of the supported names.
        """
        if pool_type == 'avgpool':
            return FastGlobalAvgPool2d()
        if pool_type == 'maxpool':
            return nn.AdaptiveMaxPool2d(1)
        if pool_type == 'gempool':
            return GeneralizedMeanPoolingP()
        if pool_type == "avgmaxpool":
            return AdaptiveAvgMaxPool2d()
        if pool_type == "identity":
            return nn.Identity()
        raise KeyError(f"{pool_type} is invalid, please choose from "
                       f"'avgpool', 'maxpool', 'gempool', 'avgmaxpool' and 'identity'.")

    @property
    def device(self):
        """Device the model lives on (taken from the ``pixel_mean`` buffer)."""
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """Run the backbone and head on one batch.

        Args:
            batched_inputs: mapping with an ``"images"`` tensor and, in training
                mode, a ``"targets"`` tensor of identity labels.

        Returns:
            Training mode: ``(self.heads(features, targets), targets)``.
            Eval mode: ``self.heads(features)``.

        Neither input tensor is modified in place.
        """
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images)

        if not self.training:
            return self.heads(features)

        assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
        targets = batched_inputs["targets"].long().to(self.device)

        # PreciseBN flag: when PreciseBN runs on a different dataset, that
        # dataset may have more classes than the original one, and circle /
        # arcface heads would raise. A negative target sum marks such a batch,
        # and every target is reset to 0 to avoid the error. Build a fresh
        # tensor rather than calling zero_(), because .long().to() may return
        # the caller's own tensor.
        if targets.sum() < 0:
            targets = torch.zeros_like(targets)

        return self.heads(features, targets), targets

    def preprocess_image(self, batched_inputs):
        """Move ``batched_inputs["images"]`` to the model device and normalize it.

        Returns ``(images - pixel_mean) / pixel_std``. The subtraction allocates
        a new tensor, so the caller's batch is never overwritten. ``Tensor.to``
        returns the same tensor when device and dtype already match, so
        normalizing in place would alias the input. Integer inputs are promoted
        to the buffers' floating dtype.
        """
        images = batched_inputs["images"].to(self.device)
        # sub() allocates the output; div_() then reuses that private buffer.
        return images.sub(self.pixel_mean).div_(self.pixel_std)

    def losses(self, outputs, gt_labels):
        r"""
        Compute loss from modeling's outputs. The loss function input arguments
        must be the same as the outputs of the model forwarding.

        Args:
            outputs: 3-tuple ``(cls_outputs, pred_class_logits, pred_features)``.
                Presumably the head output returned as the first element of
                training-mode ``forward``; confirm against the head in use.
            gt_labels: identity labels matching ``outputs`` (presumably the
                ``targets`` returned by ``forward``).

        Returns:
            dict with ``'loss_cls'`` when ``"CrossEntropyLoss"`` is in
            ``MODEL.LOSSES.NAME`` and ``'loss_triplet'`` when ``"TripletLoss"`` is.
        """
        cls_outputs, pred_class_logits, pred_features = outputs
        loss_names = self._cfg.MODEL.LOSSES.NAME
        loss_dict = {}

        # Log prediction accuracy (detached: logging only, no gradient).
        CrossEntropyLoss.log_accuracy(pred_class_logits.detach(), gt_labels)

        if "CrossEntropyLoss" in loss_names:
            loss_dict['loss_cls'] = CrossEntropyLoss(self._cfg)(cls_outputs, gt_labels)

        if "TripletLoss" in loss_names:
            loss_dict['loss_triplet'] = TripletLoss(self._cfg)(pred_features, gt_labels)

        return loss_dict
|