mirror of https://github.com/JDAI-CV/fast-reid.git
refactor reid head
Summary: merge BNneckHead, LinearHead and ReductionHead into EmbeddingHead because they are highly similar and can be prepared for ClsHead
pull/299/head
parent
77caa01e34
commit
4d573b8107
|
@ -32,10 +32,10 @@ _C.MODEL.BACKBONE = CN()
|
|||
_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
|
||||
_C.MODEL.BACKBONE.DEPTH = "50x"
|
||||
_C.MODEL.BACKBONE.LAST_STRIDE = 1
|
||||
# Backbone feature dimension
|
||||
_C.MODEL.BACKBONE.FEAT_DIM = 2048
|
||||
# Normalization method for the convolution layers.
|
||||
_C.MODEL.BACKBONE.NORM = "BN"
|
||||
# Mini-batch split of Ghost BN
|
||||
_C.MODEL.BACKBONE.NORM_SPLIT = 1
|
||||
# Whether to use IBN blocks in the backbone
|
||||
_C.MODEL.BACKBONE.WITH_IBN = False
|
||||
# Whether to use SE blocks in the backbone
|
||||
|
@ -51,18 +51,15 @@ _C.MODEL.BACKBONE.PRETRAIN_PATH = ''
|
|||
# REID HEADS options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.MODEL.HEADS = CN()
|
||||
_C.MODEL.HEADS.NAME = "BNneckHead"
|
||||
|
||||
_C.MODEL.HEADS.NAME = "EmbeddingHead"
|
||||
# Normalization method for the convolution layers.
|
||||
_C.MODEL.HEADS.NORM = "BN"
|
||||
# Mini-batch split of Ghost BN
|
||||
_C.MODEL.HEADS.NORM_SPLIT = 1
|
||||
# Number of identities (classes)
|
||||
_C.MODEL.HEADS.NUM_CLASSES = 0
|
||||
# Input feature dimension
|
||||
_C.MODEL.HEADS.IN_FEAT = 2048
|
||||
# Reduction dimension in head
|
||||
_C.MODEL.HEADS.REDUCTION_DIM = 512
|
||||
# Embedding dimension in head
|
||||
_C.MODEL.HEADS.EMBEDDING_DIM = 0
|
||||
# Whether to use a BNneck in the embedding head
|
||||
_C.MODEL.HEADS.WITH_BNNECK = True
|
||||
# Whether the triplet feature is taken before or after the BNneck
|
||||
_C.MODEL.HEADS.NECK_FEAT = "before" # options: before, after
|
||||
# Pooling layer type
|
||||
|
@ -274,4 +271,3 @@ _C.OUTPUT_DIR = "logs/"
|
|||
# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
|
||||
# If input images have the same or similar sizes, benchmark is often helpful.
|
||||
_C.CUDNN_BENCHMARK = False
|
||||
|
||||
|
|
|
@ -373,8 +373,8 @@ class DefaultTrainer(SimpleTrainer):
|
|||
Overwrite it if you'd like a different model.
|
||||
"""
|
||||
model = build_model(cfg)
|
||||
# logger = logging.getLogger(__name__)
|
||||
# logger.info("Model:\n{}".format(model))
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Model:\n{}".format(model))
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -280,7 +280,7 @@ class AutogradProfiler(HookBase):
|
|||
self._profiler.export_chrome_trace(out_file)
|
||||
else:
|
||||
# Support non-posix filesystems
|
||||
with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
|
||||
with tempfile.TemporaryDirectory(prefix="fastreid_profiler") as d:
|
||||
tmp_file = os.path.join(d, "tmp.json")
|
||||
self._profiler.export_chrome_trace(tmp_file)
|
||||
with open(tmp_file) as f:
|
||||
|
|
|
@ -7,6 +7,4 @@
|
|||
from .build import REID_HEADS_REGISTRY, build_reid_heads
|
||||
|
||||
# import all the heads, so they will be registered
|
||||
from .linear_head import LinearHead
|
||||
from .bnneck_head import BNneckHead
|
||||
from .reduction_head import ReductionHead
|
||||
from .embedding_head import EmbeddingHead
|
||||
|
|
|
@ -1,88 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from fastreid.layers import *
|
||||
from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier
|
||||
from .build import REID_HEADS_REGISTRY
|
||||
|
||||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
class BNneckHead(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
# fmt: off
|
||||
in_feat = cfg.MODEL.HEADS.IN_FEAT
|
||||
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
|
||||
pool_type = cfg.MODEL.HEADS.POOL_LAYER
|
||||
|
||||
if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d()
|
||||
elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
|
||||
elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
|
||||
elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPoolingP()
|
||||
elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d()
|
||||
elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
|
||||
elif pool_type == "identity": self.pool_layer = nn.Identity()
|
||||
else:
|
||||
raise KeyError(f"{pool_type} is invalid, please choose from "
|
||||
f"'avgpool', 'fastavgpool', 'maxpool', 'gempool', "
|
||||
f"'avgmaxpool', 'clipavgpool' and 'identity'.")
|
||||
# fmt: on
|
||||
|
||||
self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, in_feat, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
|
||||
self.bnneck.apply(weights_init_kaiming)
|
||||
|
||||
# identity classification layer
|
||||
cls_type = cfg.MODEL.HEADS.CLS_LAYER
|
||||
# fmt: off
|
||||
if cls_type == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False)
|
||||
elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, in_feat, num_classes)
|
||||
elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, in_feat, num_classes)
|
||||
elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, in_feat, num_classes)
|
||||
else:
|
||||
raise KeyError(f"{cls_type} is invalid, please choose from "
|
||||
f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
|
||||
# fmt: on
|
||||
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
"""
|
||||
See :class:`ReIDHeads.forward`.
|
||||
"""
|
||||
global_feat = self.pool_layer(features)
|
||||
bn_feat = self.bnneck(global_feat)
|
||||
bn_feat = bn_feat[..., 0, 0]
|
||||
|
||||
# Evaluation
|
||||
# fmt: off
|
||||
if not self.training: return bn_feat
|
||||
# fmt: on
|
||||
|
||||
# Training
|
||||
if self.classifier.__class__.__name__ == 'Linear':
|
||||
cls_outputs = self.classifier(bn_feat)
|
||||
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
|
||||
else:
|
||||
cls_outputs = self.classifier(bn_feat, targets)
|
||||
pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat),
|
||||
F.normalize(self.classifier.weight))
|
||||
|
||||
# fmt: off
|
||||
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
|
||||
elif self.neck_feat == "after": feat = bn_feat
|
||||
else:
|
||||
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
|
||||
# fmt: on
|
||||
|
||||
return {
|
||||
"cls_outputs": cls_outputs,
|
||||
"pred_class_logits": pred_class_logits,
|
||||
"features": feat,
|
||||
}
|
|
@ -4,8 +4,8 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from fastreid.layers import *
|
||||
from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier
|
||||
|
@ -13,47 +13,53 @@ from .build import REID_HEADS_REGISTRY
|
|||
|
||||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
class ReductionHead(nn.Module):
|
||||
class EmbeddingHead(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
# fmt: off
|
||||
in_feat = cfg.MODEL.HEADS.IN_FEAT
|
||||
reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM
|
||||
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
|
||||
pool_type = cfg.MODEL.HEADS.POOL_LAYER
|
||||
feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
|
||||
embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
|
||||
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||
neck_feat = cfg.MODEL.HEADS.NECK_FEAT
|
||||
pool_type = cfg.MODEL.HEADS.POOL_LAYER
|
||||
cls_type = cfg.MODEL.HEADS.CLS_LAYER
|
||||
with_bnneck = cfg.MODEL.HEADS.WITH_BNNECK
|
||||
norm_type = cfg.MODEL.HEADS.NORM
|
||||
|
||||
if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d()
|
||||
elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
|
||||
elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
|
||||
elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPoolingP()
|
||||
elif pool_type == 'gempoolP': self.pool_layer = GeneralizedMeanPoolingP()
|
||||
elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPooling()
|
||||
elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d()
|
||||
elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
|
||||
elif pool_type == "identity": self.pool_layer = nn.Identity()
|
||||
else:
|
||||
raise KeyError(f"{pool_type} is invalid, please choose from "
|
||||
f"'avgpool', 'fastavgpool', 'maxpool', 'gempool', "
|
||||
f"'avgmaxpool', 'clipavgpool' and 'identity'.")
|
||||
elif pool_type == "flatten": self.pool_layer = Flatten()
|
||||
else: raise KeyError(f"{pool_type} is not supported!")
|
||||
# fmt: on
|
||||
|
||||
self.bottleneck = nn.Sequential(
|
||||
nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False),
|
||||
get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True),
|
||||
)
|
||||
self.neck_feat = neck_feat
|
||||
|
||||
self.bottleneck.apply(weights_init_kaiming)
|
||||
bottleneck = []
|
||||
if embedding_dim > 0:
|
||||
bottleneck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
|
||||
feat_dim = embedding_dim
|
||||
|
||||
if with_bnneck:
|
||||
bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
|
||||
|
||||
self.bottleneck = nn.Sequential(*bottleneck)
|
||||
|
||||
# identity classification layer
|
||||
cls_type = cfg.MODEL.HEADS.CLS_LAYER
|
||||
# fmt: off
|
||||
if cls_type == 'linear': self.classifier = nn.Linear(reduction_dim, num_classes, bias=False)
|
||||
elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, reduction_dim, num_classes)
|
||||
elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, reduction_dim, num_classes)
|
||||
elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, reduction_dim, num_classes)
|
||||
else:
|
||||
raise KeyError(f"{cls_type} is invalid, please choose from "
|
||||
f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
|
||||
if cls_type == 'linear': self.classifier = nn.Linear(feat_dim, num_classes, bias=False)
|
||||
elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, feat_dim, num_classes)
|
||||
elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, feat_dim, num_classes)
|
||||
elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, feat_dim, num_classes)
|
||||
else: raise KeyError(f"{cls_type} is not supported!")
|
||||
# fmt: on
|
||||
|
||||
self.bottleneck.apply(weights_init_kaiming)
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
|
@ -81,8 +87,7 @@ class ReductionHead(nn.Module):
|
|||
# fmt: off
|
||||
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
|
||||
elif self.neck_feat == "after": feat = bn_feat
|
||||
else:
|
||||
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
|
||||
else: raise KeyError(f"{self.neck_feat} is invalid for MODEL.HEADS.NECK_FEAT")
|
||||
# fmt: on
|
||||
|
||||
return {
|
|
@ -1,75 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from fastreid.layers import *
|
||||
from fastreid.utils.weight_init import weights_init_classifier
|
||||
from .build import REID_HEADS_REGISTRY
|
||||
|
||||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
class LinearHead(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
# fmt: off
|
||||
in_feat = cfg.MODEL.HEADS.IN_FEAT
|
||||
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
|
||||
pool_type = cfg.MODEL.HEADS.POOL_LAYER
|
||||
|
||||
if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d()
|
||||
elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
|
||||
elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
|
||||
elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPoolingP()
|
||||
elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d()
|
||||
elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
|
||||
elif pool_type == "identity": self.pool_layer = nn.Identity()
|
||||
else:
|
||||
raise KeyError(f"{pool_type} is invalid, please choose from "
|
||||
f"'avgpool', 'fastavgpool', 'maxpool', 'gempool', "
|
||||
f"'avgmaxpool', 'clipavgpool' and 'identity'.")
|
||||
|
||||
# identity classification layer
|
||||
cls_type = cfg.MODEL.HEADS.CLS_LAYER
|
||||
if cls_type == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False)
|
||||
elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, in_feat, num_classes)
|
||||
elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, in_feat, num_classes)
|
||||
elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, in_feat, num_classes)
|
||||
else:
|
||||
raise KeyError(f"{cls_type} is invalid, please choose from "
|
||||
f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
|
||||
# fmt: on
|
||||
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
"""
|
||||
See :class:`ReIDHeads.forward`.
|
||||
"""
|
||||
global_feat = self.pool_layer(features)
|
||||
global_feat = global_feat[..., 0, 0]
|
||||
|
||||
# Evaluation
|
||||
# fmt: off
|
||||
if not self.training: return global_feat
|
||||
# fmt: on
|
||||
|
||||
# Training
|
||||
if self.classifier.__class__.__name__ == 'Linear':
|
||||
cls_outputs = self.classifier(global_feat)
|
||||
pred_class_logits = F.linear(global_feat, self.classifier.weight)
|
||||
else:
|
||||
cls_outputs = self.classifier(global_feat, targets)
|
||||
pred_class_logits = self.classifier.s * F.linear(F.normalize(global_feat),
|
||||
F.normalize(self.classifier.weight))
|
||||
|
||||
return {
|
||||
"cls_outputs": cls_outputs,
|
||||
"pred_class_logits": pred_class_logits,
|
||||
"features": global_feat,
|
||||
}
|
|
@ -38,7 +38,7 @@ class Baseline(nn.Module):
|
|||
|
||||
if self.training:
|
||||
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
|
||||
targets = batched_inputs["targets"].long().to(self.device)
|
||||
targets = batched_inputs["targets"].to(self.device)
|
||||
|
||||
# PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
|
||||
# may be larger than that in the original dataset, so the circle/arcface will
|
||||
|
|
Loading…
Reference in New Issue