refactor code

Summary: change code style and refactor code, add avgmax pooling layer in gem_pool
pull/68/head
liaoxingyu 2020-05-28 13:49:39 +08:00
parent d63a3ce47c
commit 5528d17ace
9 changed files with 66 additions and 65 deletions

View File

@ -13,7 +13,7 @@ from .non_local import Non_local
from .se_layer import SELayer
from .frn import FRN, TLU
from .activation import *
from .gem_pool import GeneralizedMeanPoolingP
from .gem_pool import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d
from .arcface import Arcface
from .circle import Circle
from .splat import SplAtConv2d

View File

@ -15,17 +15,13 @@ from fastreid.utils.one_hot import one_hot
class Arcface(nn.Module):
def __init__(self, cfg, in_feat):
def __init__(self, cfg, in_feat, num_classes):
super().__init__()
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self._num_classes = num_classes
self._s = cfg.MODEL.HEADS.SCALE
self._m = cfg.MODEL.HEADS.MARGIN
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, features, targets):
# get cos(theta)

View File

@ -15,17 +15,13 @@ from fastreid.utils.one_hot import one_hot
class Circle(nn.Module):
def __init__(self, cfg, in_feat):
def __init__(self, cfg, in_feat, num_classes):
super().__init__()
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self._num_classes = num_classes
self._s = cfg.MODEL.HEADS.SCALE
self._m = cfg.MODEL.HEADS.MARGIN
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, features, targets):
sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))

View File

@ -47,3 +47,16 @@ class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
def __init__(self, norm=3, output_size=1, eps=1e-6):
super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
self.p = nn.Parameter(torch.ones(1) * norm)
class AdaptiveAvgMaxPool2d(nn.Module):
def __init__(self, output_size):
super(AdaptiveAvgMaxPool2d, self).__init__()
self.output_size = output_size
def forward(self, x):
x_max = F.adaptive_avg_pool2d(x, self.output_size)
x_avg = F.adaptive_max_pool2d(x, self.output_size)
x = x_max + x_avg
return x

View File

@ -20,16 +20,15 @@ class BNneckHead(nn.Module):
self.bnneck.apply(weights_init_kaiming)
# identity classification layer
if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
self.classifier.apply(weights_init_classifier)
elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
self.classifier = Arcface(cfg, in_feat)
elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
self.classifier = Circle(cfg, in_feat)
cls_type = cfg.MODEL.HEADS.CLS_LAYER
if cls_type == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False)
elif cls_type == 'arcface': self.classifier = Arcface(cfg, in_feat, num_classes)
elif cls_type == 'circle': self.classifier = Circle(cfg, in_feat, num_classes)
else:
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
self.classifier.apply(weights_init_classifier)
raise KeyError(f"{cls_type} is invalid, please choose from "
f"'linear', 'arcface' and 'circle'.")
self.classifier.apply(weights_init_classifier)
def forward(self, features, targets=None):
"""
@ -39,18 +38,13 @@ class BNneckHead(nn.Module):
bn_feat = self.bnneck(global_feat)
bn_feat = bn_feat[..., 0, 0]
# Evaluation
if not self.training:
return bn_feat
if not self.training: return bn_feat
# Training
try:
pred_class_logits = self.classifier(bn_feat)
except TypeError:
pred_class_logits = self.classifier(bn_feat, targets)
try: pred_class_logits = self.classifier(bn_feat)
except TypeError: pred_class_logits = self.classifier(bn_feat, targets)
if self.neck_feat == "before":
feat = global_feat[..., 0, 0]
elif self.neck_feat == "after":
feat = bn_feat
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
elif self.neck_feat == "after": feat = bn_feat
else:
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
return pred_class_logits, feat, targets

View File

@ -6,6 +6,7 @@
from fastreid.layers import *
from .build import REID_HEADS_REGISTRY
from fastreid.utils.weight_init import weights_init_classifier
@REID_HEADS_REGISTRY.register()
@ -15,14 +16,15 @@ class LinearHead(nn.Module):
self.pool_layer = pool_layer
# identity classification layer
if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
self.classifier = Arcface(cfg, in_feat)
elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
self.classifier = Circle(cfg, in_feat)
cls_type = cfg.MODEL.HEADS.CLS_LAYER
if cls_type == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False)
elif cls_type == 'arcface': self.classifier = Arcface(cfg, in_feat, num_classes)
elif cls_type == 'circle': self.classifier = Circle(cfg, in_feat, num_classes)
else:
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
raise KeyError(f"{cls_type} is invalid, please choose from "
f"'linear', 'arcface' and 'circle'.")
self.classifier.apply(weights_init_classifier)
def forward(self, features, targets=None):
"""
@ -30,11 +32,8 @@ class LinearHead(nn.Module):
"""
global_feat = self.pool_layer(features)
global_feat = global_feat[..., 0, 0]
if not self.training:
return global_feat
if not self.training: return global_feat
# training
try:
pred_class_logits = self.classifier(global_feat)
except TypeError:
pred_class_logits = self.classifier(global_feat, targets)
try: pred_class_logits = self.classifier(global_feat)
except TypeError: pred_class_logits = self.classifier(global_feat, targets)
return pred_class_logits, global_feat, targets

View File

@ -7,7 +7,7 @@
import torch
from torch import nn
from fastreid.layers import GeneralizedMeanPoolingP
from fastreid.layers import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.heads import build_reid_heads
from fastreid.modeling.losses import reid_losses
@ -25,14 +25,15 @@ class Baseline(nn.Module):
self.backbone = build_backbone(cfg)
# head
if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
pool_layer = nn.AdaptiveAvgPool2d(1)
elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
pool_layer = nn.AdaptiveMaxPool2d(1)
elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
pool_layer = GeneralizedMeanPoolingP()
pool_type = cfg.MODEL.HEADS.POOL_LAYER
if pool_type == 'avgpool': pool_layer = nn.AdaptiveAvgPool2d(1)
elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP()
elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d(1)
elif pool_type == "identity": pool_layer = nn.Identity()
else:
pool_layer = nn.Identity()
raise KeyError(f"{pool_type} is invalid, please choose from "
f"'avgpool', 'maxpool', 'gempool', 'avgmaxpool' and 'identity'.")
in_feat = cfg.MODEL.HEADS.IN_FEAT
num_classes = cfg.MODEL.HEADS.NUM_CLASSES

View File

@ -8,7 +8,7 @@ import copy
import torch
from torch import nn
from fastreid.layers import GeneralizedMeanPoolingP, get_norm
from fastreid.layers import GeneralizedMeanPoolingP, get_norm, AdaptiveAvgMaxPool2d
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.backbones.resnet import Bottleneck
from fastreid.modeling.heads import build_reid_heads
@ -50,14 +50,15 @@ class MGN(nn.Module):
Bottleneck(2048, 512, bn_norm, num_splits, False, with_se))
res_p_conv5.load_state_dict(backbone.layer4.state_dict())
if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
pool_layer = nn.AdaptiveAvgPool2d(1)
elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
pool_layer = nn.AdaptiveMaxPool2d(1)
elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
pool_layer = GeneralizedMeanPoolingP()
pool_type = cfg.MODEL.HEADS.POOL_LAYER
if pool_type == 'avgpool': pool_layer = nn.AdaptiveAvgPool2d(1)
elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP()
elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d(1)
elif pool_type == "identity": pool_layer = nn.Identity()
else:
pool_layer = nn.Identity()
raise KeyError(f"{pool_type} is invalid, please choose from "
f"'avgpool', 'maxpool', 'gempool', 'avgmaxpool' and 'identity'.")
# head
in_feat = cfg.MODEL.HEADS.IN_FEAT
@ -117,10 +118,8 @@ class MGN(nn.Module):
def forward(self, batched_inputs):
if not self.training:
pred_feat = self.inference(batched_inputs)
try:
return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
except KeyError:
return pred_feat
try: return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
except KeyError: return pred_feat
images = self.preprocess_image(batched_inputs)
targets = batched_inputs["targets"].long()

View File

@ -4,6 +4,7 @@
@contact: liaoxingyu5@jd.com
"""
import math
from torch import nn
__all__ = [
@ -34,3 +35,5 @@ def weights_init_classifier(m):
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0.0)
elif classname.find("Arcface") and classname.find("Circle") != -1:
nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))