mirror of https://github.com/JDAI-CV/fast-reid.git

refactor code (pull/68/head)

Summary: change code style and refactor code, add avgmax pooling layer in gem_pool

parent d63a3ce47c
commit 5528d17ace
--- a/fastreid/layers/__init__.py
+++ b/fastreid/layers/__init__.py
@@ -13,7 +13,7 @@ from .non_local import Non_local
 from .se_layer import SELayer
 from .frn import FRN, TLU
 from .activation import *
-from .gem_pool import GeneralizedMeanPoolingP
+from .gem_pool import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d
 from .arcface import Arcface
 from .circle import Circle
 from .splat import SplAtConv2d
--- a/fastreid/layers/arcface.py
+++ b/fastreid/layers/arcface.py
@@ -15,17 +15,13 @@ from fastreid.utils.one_hot import one_hot


 class Arcface(nn.Module):
-    def __init__(self, cfg, in_feat):
+    def __init__(self, cfg, in_feat, num_classes):
         super().__init__()
-        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
+        self._num_classes = num_classes
         self._s = cfg.MODEL.HEADS.SCALE
         self._m = cfg.MODEL.HEADS.MARGIN

         self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

     def forward(self, features, targets):
         # get cos(theta)
--- a/fastreid/layers/circle.py
+++ b/fastreid/layers/circle.py
@@ -15,17 +15,13 @@ from fastreid.utils.one_hot import one_hot


 class Circle(nn.Module):
-    def __init__(self, cfg, in_feat):
+    def __init__(self, cfg, in_feat, num_classes):
         super().__init__()
-        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
+        self._num_classes = num_classes
         self._s = cfg.MODEL.HEADS.SCALE
         self._m = cfg.MODEL.HEADS.MARGIN

         self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

     def forward(self, features, targets):
         sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))
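Both margin-based classifiers now receive the class count explicitly instead of reading cfg.MODEL.HEADS.NUM_CLASSES themselves. A minimal construction sketch, assuming fastreid's get_cfg() default config defines MODEL.HEADS.SCALE and MODEL.HEADS.MARGIN; the feature size and class count are illustrative only:

import torch
from fastreid.config import get_cfg
from fastreid.layers import Arcface, Circle

cfg = get_cfg()
arcface = Arcface(cfg, in_feat=2048, num_classes=751)   # num_classes is now an explicit argument
circle = Circle(cfg, in_feat=2048, num_classes=751)

features = torch.randn(16, 2048)
targets = torch.randint(0, 751, (16,))
logits = arcface(features, targets)                     # margin-adjusted logits, shape (16, 751)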
--- a/fastreid/layers/gem_pool.py
+++ b/fastreid/layers/gem_pool.py
@@ -47,3 +47,16 @@ class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
     def __init__(self, norm=3, output_size=1, eps=1e-6):
         super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
         self.p = nn.Parameter(torch.ones(1) * norm)
+
+
+class AdaptiveAvgMaxPool2d(nn.Module):
+    def __init__(self, output_size):
+        super(AdaptiveAvgMaxPool2d, self).__init__()
+        self.output_size = output_size
+
+    def forward(self, x):
+        x_avg = F.adaptive_avg_pool2d(x, self.output_size)
+        x_max = F.adaptive_max_pool2d(x, self.output_size)
+        x = x_avg + x_max
+        return x
+
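A minimal behavior sketch for the new AdaptiveAvgMaxPool2d, which sums adaptive average pooling and adaptive max pooling over the same output size (the tensor shape below is illustrative):

import torch
from fastreid.layers import AdaptiveAvgMaxPool2d

feat = torch.randn(4, 2048, 16, 8)   # e.g. a backbone feature map
pool = AdaptiveAvgMaxPool2d(1)
out = pool(feat)                     # avg-pooled + max-pooled, summed elementwise
print(out.shape)                     # torch.Size([4, 2048, 1, 1])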
--- a/fastreid/modeling/heads/bnneck_head.py
+++ b/fastreid/modeling/heads/bnneck_head.py
@@ -20,16 +20,15 @@ class BNneckHead(nn.Module):
         self.bnneck.apply(weights_init_kaiming)

         # identity classification layer
-        if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
-            self.classifier = nn.Linear(in_feat, num_classes, bias=False)
-            self.classifier.apply(weights_init_classifier)
-        elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
-            self.classifier = Arcface(cfg, in_feat)
-        elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
-            self.classifier = Circle(cfg, in_feat)
+        cls_type = cfg.MODEL.HEADS.CLS_LAYER
+        if cls_type == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False)
+        elif cls_type == 'arcface': self.classifier = Arcface(cfg, in_feat, num_classes)
+        elif cls_type == 'circle': self.classifier = Circle(cfg, in_feat, num_classes)
         else:
-            self.classifier = nn.Linear(in_feat, num_classes, bias=False)
-            self.classifier.apply(weights_init_classifier)
+            raise KeyError(f"{cls_type} is invalid, please choose from "
+                           f"'linear', 'arcface' and 'circle'.")
+
+        self.classifier.apply(weights_init_classifier)

     def forward(self, features, targets=None):
         """
@@ -39,18 +38,13 @@ class BNneckHead(nn.Module):
         bn_feat = self.bnneck(global_feat)
         bn_feat = bn_feat[..., 0, 0]
         # Evaluation
-        if not self.training:
-            return bn_feat
+        if not self.training: return bn_feat
         # Training
-        try:
-            pred_class_logits = self.classifier(bn_feat)
-        except TypeError:
-            pred_class_logits = self.classifier(bn_feat, targets)
+        try: pred_class_logits = self.classifier(bn_feat)
+        except TypeError: pred_class_logits = self.classifier(bn_feat, targets)

-        if self.neck_feat == "before":
-            feat = global_feat[..., 0, 0]
-        elif self.neck_feat == "after":
-            feat = bn_feat
+        if self.neck_feat == "before": feat = global_feat[..., 0, 0]
+        elif self.neck_feat == "after": feat = bn_feat
         else:
             raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
         return pred_class_logits, feat, targets
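The single-line try/except above is how the head dispatches on the classifier's call signature: a plain nn.Linear only takes the features, while Arcface/Circle also need the targets. A self-contained sketch of that pattern (run_classifier is a hypothetical helper, not part of the codebase):

import torch
from torch import nn

def run_classifier(classifier, feat, targets):
    try: return classifier(feat)                         # linear classifier: features only
    except TypeError: return classifier(feat, targets)   # margin-based classifier: needs targets

feat = torch.randn(8, 256)
targets = torch.randint(0, 10, (8,))
print(run_classifier(nn.Linear(256, 10, bias=False), feat, targets).shape)   # torch.Size([8, 10])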
--- a/fastreid/modeling/heads/linear_head.py
+++ b/fastreid/modeling/heads/linear_head.py
@@ -6,6 +6,7 @@
 from fastreid.layers import *
 from .build import REID_HEADS_REGISTRY
+from fastreid.utils.weight_init import weights_init_classifier


 @REID_HEADS_REGISTRY.register()
@@ -15,14 +16,15 @@ class LinearHead(nn.Module):
         self.pool_layer = pool_layer

         # identity classification layer
-        if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
-            self.classifier = nn.Linear(in_feat, num_classes, bias=False)
-        elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
-            self.classifier = Arcface(cfg, in_feat)
-        elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
-            self.classifier = Circle(cfg, in_feat)
+        cls_type = cfg.MODEL.HEADS.CLS_LAYER
+        if cls_type == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False)
+        elif cls_type == 'arcface': self.classifier = Arcface(cfg, in_feat, num_classes)
+        elif cls_type == 'circle': self.classifier = Circle(cfg, in_feat, num_classes)
         else:
-            self.classifier = nn.Linear(in_feat, num_classes, bias=False)
+            raise KeyError(f"{cls_type} is invalid, please choose from "
+                           f"'linear', 'arcface' and 'circle'.")
+
+        self.classifier.apply(weights_init_classifier)

     def forward(self, features, targets=None):
         """
@@ -30,11 +32,8 @@ class LinearHead(nn.Module):
         """
         global_feat = self.pool_layer(features)
         global_feat = global_feat[..., 0, 0]
-        if not self.training:
-            return global_feat
+        if not self.training: return global_feat
         # training
-        try:
-            pred_class_logits = self.classifier(global_feat)
-        except TypeError:
-            pred_class_logits = self.classifier(global_feat, targets)
+        try: pred_class_logits = self.classifier(global_feat)
+        except TypeError: pred_class_logits = self.classifier(global_feat, targets)
         return pred_class_logits, global_feat, targets
--- a/fastreid/modeling/meta_arch/baseline.py
+++ b/fastreid/modeling/meta_arch/baseline.py
@@ -7,7 +7,7 @@
 import torch
 from torch import nn

-from fastreid.layers import GeneralizedMeanPoolingP
+from fastreid.layers import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d
 from fastreid.modeling.backbones import build_backbone
 from fastreid.modeling.heads import build_reid_heads
 from fastreid.modeling.losses import reid_losses
@@ -25,14 +25,15 @@ class Baseline(nn.Module):
         self.backbone = build_backbone(cfg)

         # head
-        if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
-            pool_layer = nn.AdaptiveAvgPool2d(1)
-        elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
-            pool_layer = nn.AdaptiveMaxPool2d(1)
-        elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
-            pool_layer = GeneralizedMeanPoolingP()
+        pool_type = cfg.MODEL.HEADS.POOL_LAYER
+        if pool_type == 'avgpool': pool_layer = nn.AdaptiveAvgPool2d(1)
+        elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1)
+        elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP()
+        elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d(1)
+        elif pool_type == "identity": pool_layer = nn.Identity()
         else:
-            pool_layer = nn.Identity()
+            raise KeyError(f"{pool_type} is invalid, please choose from "
+                           f"'avgpool', 'maxpool', 'gempool', 'avgmaxpool' and 'identity'.")

         in_feat = cfg.MODEL.HEADS.IN_FEAT
         num_classes = cfg.MODEL.HEADS.NUM_CLASSES
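The new avgmaxpool option is selected through the same config key the code above checks. A brief sketch, assuming fastreid's standard get_cfg() / yacs config:

from fastreid.config import get_cfg

cfg = get_cfg()
cfg.merge_from_list(["MODEL.HEADS.POOL_LAYER", "avgmaxpool"])   # option added by this commit
# any other value now raises KeyError at model build time instead of
# silently falling back to nn.Identity()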
--- a/fastreid/modeling/meta_arch/mgn.py
+++ b/fastreid/modeling/meta_arch/mgn.py
@@ -8,7 +8,7 @@ import copy
 import torch
 from torch import nn

-from fastreid.layers import GeneralizedMeanPoolingP, get_norm
+from fastreid.layers import GeneralizedMeanPoolingP, get_norm, AdaptiveAvgMaxPool2d
 from fastreid.modeling.backbones import build_backbone
 from fastreid.modeling.backbones.resnet import Bottleneck
 from fastreid.modeling.heads import build_reid_heads
@@ -50,14 +50,15 @@ class MGN(nn.Module):
             Bottleneck(2048, 512, bn_norm, num_splits, False, with_se))
         res_p_conv5.load_state_dict(backbone.layer4.state_dict())

-        if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
-            pool_layer = nn.AdaptiveAvgPool2d(1)
-        elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
-            pool_layer = nn.AdaptiveMaxPool2d(1)
-        elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
-            pool_layer = GeneralizedMeanPoolingP()
+        pool_type = cfg.MODEL.HEADS.POOL_LAYER
+        if pool_type == 'avgpool': pool_layer = nn.AdaptiveAvgPool2d(1)
+        elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1)
+        elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP()
+        elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d(1)
+        elif pool_type == "identity": pool_layer = nn.Identity()
         else:
-            pool_layer = nn.Identity()
+            raise KeyError(f"{pool_type} is invalid, please choose from "
+                           f"'avgpool', 'maxpool', 'gempool', 'avgmaxpool' and 'identity'.")

         # head
         in_feat = cfg.MODEL.HEADS.IN_FEAT
@@ -117,10 +118,8 @@ class MGN(nn.Module):
     def forward(self, batched_inputs):
         if not self.training:
             pred_feat = self.inference(batched_inputs)
-            try:
-                return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
-            except KeyError:
-                return pred_feat
+            try: return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
+            except KeyError: return pred_feat

         images = self.preprocess_image(batched_inputs)
         targets = batched_inputs["targets"].long()
--- a/fastreid/utils/weight_init.py
+++ b/fastreid/utils/weight_init.py
@@ -4,6 +4,7 @@
 @contact: liaoxingyu5@jd.com
 """

+import math
 from torch import nn

 __all__ = [
@@ -34,3 +35,5 @@ def weights_init_classifier(m):
         nn.init.normal_(m.weight, std=0.001)
         if m.bias is not None:
             nn.init.constant_(m.bias, 0.0)
+    elif classname.find("Arcface") != -1 or classname.find("Circle") != -1:
+        nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
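Since reset_parameters() was dropped from Arcface and Circle above, their weights are now initialized when the heads call self.classifier.apply(weights_init_classifier), which reaches the new class-name branch here. A rough sketch, assuming the same get_cfg() setup as in the earlier example:

from fastreid.config import get_cfg
from fastreid.layers import Arcface
from fastreid.utils.weight_init import weights_init_classifier

cfg = get_cfg()
classifier = Arcface(cfg, in_feat=2048, num_classes=751)
classifier.apply(weights_init_classifier)   # new branch applies kaiming_uniform_ to classifier.weight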