fast-reid/fastreid/modeling/heads/embedding_head.py

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F
from torch import nn

from fastreid.config import configurable
from fastreid.layers import *
from fastreid.layers import pooling, any_softmax
from fastreid.utils.weight_init import weights_init_kaiming
from .build import REID_HEADS_REGISTRY


@REID_HEADS_REGISTRY.register()
class EmbeddingHead(nn.Module):
"""
EmbeddingHead perform all feature aggregation in an embedding task, such as reid, image retrieval
and face recognition
It typically contains logic to
1. feature aggregation via global average pooling and generalized mean pooling
2. (optional) batchnorm, dimension reduction and etc.
2. (in training only) margin-based softmax logits computation
"""

    @configurable
    def __init__(
            self,
            *,
            feat_dim,
            embedding_dim,
            num_classes,
            neck_feat,
            pool_type,
            cls_type,
            scale,
            margin,
            with_bnneck,
            norm_type
    ):
"""
NOTE: this interface is experimental.
Args:
feat_dim:
embedding_dim:
num_classes:
neck_feat:
pool_type:
cls_type:
scale:
margin:
with_bnneck:
norm_type:
"""
        super().__init__()

        # Pooling layer
        assert hasattr(pooling, pool_type), "Expected pool types are {}, " \
                                            "but got {}".format(pooling.__all__, pool_type)
        self.pool_layer = getattr(pooling, pool_type)()

        self.neck_feat = neck_feat

        neck = []
        if embedding_dim > 0:
            # Optional 1x1 conv reduces feat_dim to embedding_dim
            neck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
            feat_dim = embedding_dim

        if with_bnneck:
            neck.append(get_norm(norm_type, feat_dim, bias_freeze=True))

        self.bottleneck = nn.Sequential(*neck)
        self.bottleneck.apply(weights_init_kaiming)

        # Classification layer: a bare weight Parameter so forward() can use it either
        # as a plain linear classifier or on L2-normalized features for margin softmax
        assert hasattr(any_softmax, cls_type), "Expected cls types are {}, " \
                                               "but got {}".format(any_softmax.__all__, cls_type)
        self.weight = nn.Parameter(torch.normal(0, 0.01, (num_classes, feat_dim)))
        self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)

    @classmethod
    def from_config(cls, cfg):
        # fmt: off
        feat_dim      = cfg.MODEL.BACKBONE.FEAT_DIM
        embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
        num_classes   = cfg.MODEL.HEADS.NUM_CLASSES
        neck_feat     = cfg.MODEL.HEADS.NECK_FEAT
        pool_type     = cfg.MODEL.HEADS.POOL_LAYER
        cls_type      = cfg.MODEL.HEADS.CLS_LAYER
        scale         = cfg.MODEL.HEADS.SCALE
        margin        = cfg.MODEL.HEADS.MARGIN
        with_bnneck   = cfg.MODEL.HEADS.WITH_BNNECK
        norm_type     = cfg.MODEL.HEADS.NORM
        # fmt: on
        return {
            'feat_dim': feat_dim,
            'embedding_dim': embedding_dim,
            'num_classes': num_classes,
            'neck_feat': neck_feat,
            'pool_type': pool_type,
            'cls_type': cls_type,
            'scale': scale,
            'margin': margin,
            'with_bnneck': with_bnneck,
            'norm_type': norm_type
        }
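
    # A hypothetical YACS config fragment mapped by from_config; the key paths are
    # taken from the cfg accesses above, the values are illustrative defaults only:
    #
    #   MODEL:
    #     BACKBONE:
    #       FEAT_DIM: 2048
    #     HEADS:
    #       EMBEDDING_DIM: 0
    #       NUM_CLASSES: 751
    #       NECK_FEAT: "before"
    #       POOL_LAYER: "GlobalAvgPool"
    #       CLS_LAYER: "Linear"
    #       SCALE: 1
    #       MARGIN: 0.0
    #       WITH_BNNECK: True
    #       NORM: "BN"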

    def forward(self, features, targets=None):
        """
        See :class:`ReIDHeads.forward`.
        """
        pool_feat = self.pool_layer(features)
        neck_feat = self.bottleneck(pool_feat)
        # Global pooling leaves 1x1 spatial dims; squeeze (N, C, 1, 1) -> (N, C)
        neck_feat = neck_feat[..., 0, 0]

        # Evaluation
        # fmt: off
        if not self.training: return neck_feat
        # fmt: on

        # Training
        if self.cls_layer.__class__.__name__ == 'Linear':
            logits = F.linear(neck_feat, self.weight)
        else:
            logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
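        # In the normalized branch the logits are cosine similarities in [-1, 1];
        # the margin-based cls_layer (e.g. an Arc/Cos/Circle softmax, depending on
        # cls_type) then perturbs the target-class logits by its margin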
        cls_outputs = self.cls_layer(logits, targets)

        # fmt: off
        if self.neck_feat == 'before':  feat = pool_feat[..., 0, 0]
        elif self.neck_feat == 'after': feat = neck_feat
        else:                           raise KeyError(f"{self.neck_feat} is invalid for MODEL.HEADS.NECK_FEAT")
        # fmt: on

        return {
            "cls_outputs": cls_outputs,
            "pred_class_logits": logits * self.cls_layer.s,
            "features": feat,
        }
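

# A minimal forward-pass sketch for a head built as in the construction example near
# the top of the class (illustrative shapes for a ResNet-50 style backbone):
#
#   x = torch.randn(4, 2048, 16, 8)                       # backbone feature map
#   head.eval();  emb = head(x)                           # inference: (4, 512) embeddings
#   head.train(); out = head(x, torch.zeros(4).long())    # training: dict of outputs
#   # out keys: "cls_outputs", "pred_class_logits", "features"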