refactor reid head

Summary: merge BNneckHead, LinearHead and ReductionHead into EmbeddingHead,
because they are highly similar; this also lays the groundwork for ClsHead.
pull/299/head
liaoxingyu 2020-09-10 10:57:37 +08:00
parent 77caa01e34
commit 4d573b8107
8 changed files with 44 additions and 208 deletions


@@ -32,10 +32,10 @@ _C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
_C.MODEL.BACKBONE.DEPTH = "50x"
_C.MODEL.BACKBONE.LAST_STRIDE = 1
# Backbone feature dimension
_C.MODEL.BACKBONE.FEAT_DIM = 2048
# Normalization method for the convolution layers.
_C.MODEL.BACKBONE.NORM = "BN"
# Mini-batch split of Ghost BN
_C.MODEL.BACKBONE.NORM_SPLIT = 1
# If use IBN block in backbone
_C.MODEL.BACKBONE.WITH_IBN = False
# If use SE block in backbone
@@ -51,18 +51,15 @@ _C.MODEL.BACKBONE.PRETRAIN_PATH = ''
# REID HEADS options
# ---------------------------------------------------------------------------- #
_C.MODEL.HEADS = CN()
_C.MODEL.HEADS.NAME = "BNneckHead"
_C.MODEL.HEADS.NAME = "EmbeddingHead"
# Normalization method for the convolution layers.
_C.MODEL.HEADS.NORM = "BN"
# Mini-batch split of Ghost BN
_C.MODEL.HEADS.NORM_SPLIT = 1
# Number of identity
_C.MODEL.HEADS.NUM_CLASSES = 0
# Input feature dimension
_C.MODEL.HEADS.IN_FEAT = 2048
# Reduction dimension in head
_C.MODEL.HEADS.REDUCTION_DIM = 512
# Embedding dimension in head
_C.MODEL.HEADS.EMBEDDING_DIM = 0
# If use BNneck in embedding
_C.MODEL.HEADS.WITH_BNNECK = True
# Triplet feature using feature before(after) bnneck
_C.MODEL.HEADS.NECK_FEAT = "before" # options: before, after
# Pooling layer type
@@ -274,4 +271,3 @@ _C.OUTPUT_DIR = "logs/"
# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
# If input images have the same or similar sizes, benchmark is often helpful.
_C.CUDNN_BENCHMARK = False
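With the three heads merged, each removed head is recovered from EmbeddingHead purely through the options above (EMBEDDING_DIM = 0 keeps the backbone's FEAT_DIM). A minimal sketch of the mapping, assuming fastreid's get_cfg() helper, which is not part of this diff:

from fastreid.config import get_cfg  # assumed fastreid API, unchanged by this commit

cfg = get_cfg()
cfg.MODEL.HEADS.NAME = "EmbeddingHead"

# LinearHead equivalent: pooled feature goes straight to the classifier
cfg.MODEL.HEADS.WITH_BNNECK = False
cfg.MODEL.HEADS.EMBEDDING_DIM = 0

# BNneckHead equivalent: bias-frozen BN neck, no reduction
cfg.MODEL.HEADS.WITH_BNNECK = True
cfg.MODEL.HEADS.EMBEDDING_DIM = 0

# ReductionHead equivalent: 1x1 conv down to 512 dims, then the BN neck
cfg.MODEL.HEADS.WITH_BNNECK = True
cfg.MODEL.HEADS.EMBEDDING_DIM = 512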


@@ -373,8 +373,8 @@ class DefaultTrainer(SimpleTrainer):
Overwrite it if you'd like a different model.
"""
model = build_model(cfg)
# logger = logging.getLogger(__name__)
# logger.info("Model:\n{}".format(model))
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
return model
@classmethod


@@ -280,7 +280,7 @@ class AutogradProfiler(HookBase):
self._profiler.export_chrome_trace(out_file)
else:
# Support non-posix filesystems
with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
with tempfile.TemporaryDirectory(prefix="fastreid_profiler") as d:
tmp_file = os.path.join(d, "tmp.json")
self._profiler.export_chrome_trace(tmp_file)
with open(tmp_file) as f:


@@ -7,6 +7,4 @@
from .build import REID_HEADS_REGISTRY, build_reid_heads
# import all the meta_arch, so they will be registered
from .linear_head import LinearHead
from .bnneck_head import BNneckHead
from .reduction_head import ReductionHead
from .embedding_head import EmbeddingHead
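Because every head registers itself in REID_HEADS_REGISTRY, call sites keep going through the factory and only the configured NAME changes. A sketch under the assumption that build_reid_heads resolves NAME via the registry (implied by the import above, not shown in this diff):

from fastreid.modeling.heads import build_reid_heads

head = build_reid_heads(cfg)  # looks up cfg.MODEL.HEADS.NAME, now "EmbeddingHead"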


@@ -1,88 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from torch import nn
from fastreid.layers import *
from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier
from .build import REID_HEADS_REGISTRY
@REID_HEADS_REGISTRY.register()
class BNneckHead(nn.Module):
def __init__(self, cfg):
super().__init__()
# fmt: off
in_feat = cfg.MODEL.HEADS.IN_FEAT
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
pool_type = cfg.MODEL.HEADS.POOL_LAYER
if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d()
elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPoolingP()
elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d()
elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
elif pool_type == "identity": self.pool_layer = nn.Identity()
else:
raise KeyError(f"{pool_type} is invalid, please choose from "
f"'avgpool', 'fastavgpool', 'maxpool', 'gempool', "
f"'avgmaxpool', 'clipavgpool' and 'identity'.")
# fmt: on
self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, in_feat, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
self.bnneck.apply(weights_init_kaiming)
# identity classification layer
cls_type = cfg.MODEL.HEADS.CLS_LAYER
# fmt: off
if cls_type == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False)
elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, in_feat, num_classes)
elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, in_feat, num_classes)
elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, in_feat, num_classes)
else:
raise KeyError(f"{cls_type} is invalid, please choose from "
f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
# fmt: on
self.classifier.apply(weights_init_classifier)
def forward(self, features, targets=None):
"""
See :class:`ReIDHeads.forward`.
"""
global_feat = self.pool_layer(features)
bn_feat = self.bnneck(global_feat)
bn_feat = bn_feat[..., 0, 0]
# Evaluation
# fmt: off
if not self.training: return bn_feat
# fmt: on
# Training
if self.classifier.__class__.__name__ == 'Linear':
cls_outputs = self.classifier(bn_feat)
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
else:
cls_outputs = self.classifier(bn_feat, targets)
pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat),
F.normalize(self.classifier.weight))
# fmt: off
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
elif self.neck_feat == "after": feat = bn_feat
else:
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
# fmt: on
return {
"cls_outputs": cls_outputs,
"pred_class_logits": pred_class_logits,
"features": feat,
}
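The deleted forward() above separates a plain Linear classifier from the margin-based ones: for ArcSoftmax, CircleSoftmax and AMSoftmax, pred_class_logits are the margin-free scaled cosine scores, while cls_outputs additionally carry the margin. A standalone sketch of that computation (the function name is illustrative, not fastreid API):

import torch
import torch.nn.functional as F

def monitoring_logits(feat: torch.Tensor, weight: torch.Tensor, s: float) -> torch.Tensor:
    # Scaled cosine similarity between L2-normalized features and class
    # weights; no margin is applied, unlike the cls_outputs branch.
    return s * F.linear(F.normalize(feat), F.normalize(weight))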


@@ -4,8 +4,8 @@
@contact: sherlockliao01@gmail.com
"""
from torch import nn
import torch.nn.functional as F
from torch import nn
from fastreid.layers import *
from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier
@@ -13,47 +13,53 @@ from .build import REID_HEADS_REGISTRY
@REID_HEADS_REGISTRY.register()
class ReductionHead(nn.Module):
class EmbeddingHead(nn.Module):
def __init__(self, cfg):
super().__init__()
# fmt: off
in_feat = cfg.MODEL.HEADS.IN_FEAT
reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
pool_type = cfg.MODEL.HEADS.POOL_LAYER
feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
neck_feat = cfg.MODEL.HEADS.NECK_FEAT
pool_type = cfg.MODEL.HEADS.POOL_LAYER
cls_type = cfg.MODEL.HEADS.CLS_LAYER
with_bnneck = cfg.MODEL.HEADS.WITH_BNNECK
norm_type = cfg.MODEL.HEADS.NORM
if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d()
elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPoolingP()
elif pool_type == 'gempoolP': self.pool_layer = GeneralizedMeanPoolingP()
elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPooling()
elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d()
elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
elif pool_type == "identity": self.pool_layer = nn.Identity()
else:
raise KeyError(f"{pool_type} is invalid, please choose from "
f"'avgpool', 'fastavgpool', 'maxpool', 'gempool', "
f"'avgmaxpool', 'clipavgpool' and 'identity'.")
elif pool_type == "flatten": self.pool_layer = Flatten()
else: raise KeyError(f"{pool_type} is not supported!")
# fmt: on
self.bottleneck = nn.Sequential(
nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False),
get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True),
)
self.neck_feat = neck_feat
self.bottleneck.apply(weights_init_kaiming)
bottleneck = []
if embedding_dim > 0:
bottleneck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
feat_dim = embedding_dim
if with_bnneck:
bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
self.bottleneck = nn.Sequential(*bottleneck)
# identity classification layer
cls_type = cfg.MODEL.HEADS.CLS_LAYER
# fmt: off
if cls_type == 'linear': self.classifier = nn.Linear(reduction_dim, num_classes, bias=False)
elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, reduction_dim, num_classes)
elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, reduction_dim, num_classes)
elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, reduction_dim, num_classes)
else:
raise KeyError(f"{cls_type} is invalid, please choose from "
f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
if cls_type == 'linear': self.classifier = nn.Linear(feat_dim, num_classes, bias=False)
elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, feat_dim, num_classes)
elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, feat_dim, num_classes)
elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, feat_dim, num_classes)
else: raise KeyError(f"{cls_type} is not supported!")
# fmt: on
self.bottleneck.apply(weights_init_kaiming)
self.classifier.apply(weights_init_classifier)
def forward(self, features, targets=None):
@@ -81,8 +87,7 @@ class ReductionHead(nn.Module):
# fmt: off
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
elif self.neck_feat == "after": feat = bn_feat
else:
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
else: raise KeyError(f"{self.neck_feat} is invalid for MODEL.HEADS.NECK_FEAT")
# fmt: on
return {
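The structural win is that the bottleneck is now composed from optional pieces instead of being hard-coded per head class. A self-contained sketch of that composition logic, substituting plain BatchNorm2d for fastreid's get_norm:

import torch.nn as nn

def make_bottleneck(feat_dim: int, embedding_dim: int, with_bnneck: bool) -> nn.Sequential:
    # Illustrative mirror of the bottleneck assembly in the diff above.
    layers = []
    if embedding_dim > 0:   # ReductionHead behaviour: 1x1 conv reduction
        layers.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
        feat_dim = embedding_dim
    if with_bnneck:         # BNneckHead behaviour: BN neck over the embedding
        layers.append(nn.BatchNorm2d(feat_dim))
    return nn.Sequential(*layers)

An empty nn.Sequential passes its input through unchanged, so WITH_BNNECK=False with EMBEDDING_DIM=0 reproduces LinearHead without any special-casing.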


@@ -1,75 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from torch import nn
from fastreid.layers import *
from fastreid.utils.weight_init import weights_init_classifier
from .build import REID_HEADS_REGISTRY
@REID_HEADS_REGISTRY.register()
class LinearHead(nn.Module):
def __init__(self, cfg):
super().__init__()
# fmt: off
in_feat = cfg.MODEL.HEADS.IN_FEAT
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
pool_type = cfg.MODEL.HEADS.POOL_LAYER
if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d()
elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPoolingP()
elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d()
elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
elif pool_type == "identity": self.pool_layer = nn.Identity()
else:
raise KeyError(f"{pool_type} is invalid, please choose from "
f"'avgpool', 'fastavgpool', 'maxpool', 'gempool', "
f"'avgmaxpool', 'clipavgpool' and 'identity'.")
# identity classification layer
cls_type = cfg.MODEL.HEADS.CLS_LAYER
if cls_type == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False)
elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, in_feat, num_classes)
elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, in_feat, num_classes)
elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, in_feat, num_classes)
else:
raise KeyError(f"{cls_type} is invalid, please choose from "
f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
# fmt: on
self.classifier.apply(weights_init_classifier)
def forward(self, features, targets=None):
"""
See :class:`ReIDHeads.forward`.
"""
global_feat = self.pool_layer(features)
global_feat = global_feat[..., 0, 0]
# Evaluation
# fmt: off
if not self.training: return global_feat
# fmt: on
# Training
if self.classifier.__class__.__name__ == 'Linear':
cls_outputs = self.classifier(global_feat)
pred_class_logits = F.linear(global_feat, self.classifier.weight)
else:
cls_outputs = self.classifier(global_feat, targets)
pred_class_logits = self.classifier.s * F.linear(F.normalize(global_feat),
F.normalize(self.classifier.weight))
return {
"cls_outputs": cls_outputs,
"pred_class_logits": pred_class_logits,
"features": global_feat,
}
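As an aside, both deleted heads (and the new EmbeddingHead) squeeze the pooled map with "[..., 0, 0]" instead of a flatten; a quick self-contained check of why that is shape-correct for the adaptive-pooling case:

import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d(1)
x = torch.randn(4, 2048, 16, 8)               # (N, C, H, W) backbone output
assert pool(x).shape == (4, 2048, 1, 1)       # pooling keeps singleton H, W
assert pool(x)[..., 0, 0].shape == (4, 2048)  # indexing drops both singletons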


@@ -38,7 +38,7 @@ class Baseline(nn.Module):
if self.training:
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
targets = batched_inputs["targets"].long().to(self.device)
targets = batched_inputs["targets"].to(self.device)
# PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
# may be larger than that in the original dataset, so the circle/arcface will