mirror of https://github.com/JDAI-CV/fast-reid.git
update fast global avgpool
Summary: update fast pool according to https://arxiv.org/pdf/2003.13630.pdfpull/150/head
parent
cbdc01a1c3
commit
56a1ab4a5d
|
@ -86,7 +86,7 @@ _C.MODEL.LOSSES.CE = CN()
|
|||
# if epsilon == 0, it means no label smooth regularization,
|
||||
# if epsilon == -1, it means adaptive label smooth regularization
|
||||
_C.MODEL.LOSSES.CE.EPSILON = 0.0
|
||||
_C.MODEL.LOSSES.CE.ALPHA = 0.3
|
||||
_C.MODEL.LOSSES.CE.ALPHA = 0.2
|
||||
_C.MODEL.LOSSES.CE.SCALE = 1.0
|
||||
|
||||
# Triplet Loss options
|
||||
|
@ -100,7 +100,8 @@ _C.MODEL.LOSSES.TRI.SCALE = 1.0
|
|||
# Circle Loss options
|
||||
_C.MODEL.LOSSES.CIRCLE = CN()
|
||||
_C.MODEL.LOSSES.CIRCLE.MARGIN = 0.25
|
||||
_C.MODEL.LOSSES.CIRCLE.SCALE = 128
|
||||
_C.MODEL.LOSSES.CIRCLE.ALPHA = 128
|
||||
_C.MODEL.LOSSES.CIRCLE.SCALE = 1.0
|
||||
|
||||
# Focal Loss options
|
||||
_C.MODEL.LOSSES.FL = CN()
|
||||
|
|
|
@ -3,22 +3,15 @@
|
|||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
from torch import nn
|
||||
|
||||
from .batch_drop import BatchDrop
|
||||
from .attention import *
|
||||
from .batch_norm import *
|
||||
from .context_block import ContextBlock
|
||||
from .non_local import Non_local
|
||||
from .se_layer import SELayer
|
||||
from .frn import FRN, TLU
|
||||
from .activation import *
|
||||
from .gem_pool import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d
|
||||
from .arcface import Arcface
|
||||
from .batch_drop import BatchDrop
|
||||
from .batch_norm import *
|
||||
from .circle import Circle
|
||||
from .context_block import ContextBlock
|
||||
from .frn import FRN, TLU
|
||||
from .non_local import Non_local
|
||||
from .pooling import *
|
||||
from .se_layer import SELayer
|
||||
from .splat import SplAtConv2d
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, input):
|
||||
return input.view(input.size(0), -1)
|
||||
|
|
|
@ -1,177 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: CASIA IVA
|
||||
@contact: jliu@nlpr.ia.ac.cn
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.nn import Module, Conv2d, Parameter, Softmax
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ['PAM_Module', 'CAM_Module', 'DANetHead',]
|
||||
|
||||
|
||||
class DANetHead(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
norm_layer: nn.Module,
|
||||
module_class: type,
|
||||
dim_collapsion: int=2):
|
||||
super(DANetHead, self).__init__()
|
||||
|
||||
inter_channels = in_channels // dim_collapsion
|
||||
|
||||
self.conv5c = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
inter_channels,
|
||||
3,
|
||||
padding=1,
|
||||
bias=False
|
||||
),
|
||||
norm_layer(inter_channels),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.attention_module = module_class(inter_channels)
|
||||
self.conv52 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
inter_channels,
|
||||
inter_channels,
|
||||
3,
|
||||
padding=1,
|
||||
bias=False
|
||||
),
|
||||
norm_layer(inter_channels),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.conv7 = nn.Sequential(
|
||||
nn.Dropout2d(0.1, False),
|
||||
nn.Conv2d(inter_channels, out_channels, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
feat2 = self.conv5c(x)
|
||||
sc_feat = self.attention_module(feat2)
|
||||
sc_conv = self.conv52(sc_feat)
|
||||
sc_output = self.conv7(sc_conv)
|
||||
|
||||
return sc_output
|
||||
|
||||
|
||||
class PAM_Module(nn.Module):
|
||||
""" Position attention module"""
|
||||
# Ref from SAGAN
|
||||
|
||||
def __init__(self, in_dim):
|
||||
super(PAM_Module, self).__init__()
|
||||
self.channel_in = in_dim
|
||||
|
||||
self.query_conv = Conv2d(
|
||||
in_channels=in_dim,
|
||||
out_channels=in_dim // 8,
|
||||
kernel_size=1
|
||||
)
|
||||
self.key_conv = Conv2d(
|
||||
in_channels=in_dim,
|
||||
out_channels=in_dim // 8,
|
||||
kernel_size=1
|
||||
)
|
||||
self.value_conv = Conv2d(
|
||||
in_channels=in_dim,
|
||||
out_channels=in_dim,
|
||||
kernel_size=1
|
||||
)
|
||||
self.gamma = Parameter(torch.zeros(1))
|
||||
|
||||
self.softmax = Softmax(dim=-1)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
inputs :
|
||||
x : input feature maps( B X C X H X W)
|
||||
returns :
|
||||
out : attention value + input feature
|
||||
attention: B X (HxW) X (HxW)
|
||||
"""
|
||||
m_batchsize, C, height, width = x.size()
|
||||
proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
|
||||
proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
|
||||
energy = torch.bmm(proj_query, proj_key)
|
||||
attention = self.softmax(energy)
|
||||
proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
|
||||
|
||||
out = torch.bmm(
|
||||
proj_value,
|
||||
attention.permute(0, 2, 1)
|
||||
)
|
||||
attention_mask = out.view(m_batchsize, C, height, width)
|
||||
|
||||
out = self.gamma * attention_mask + x
|
||||
return out
|
||||
|
||||
|
||||
class CAM_Module(nn.Module):
|
||||
""" Channel attention module"""
|
||||
|
||||
def __init__(self, in_dim):
|
||||
super().__init__()
|
||||
self.channel_in = in_dim
|
||||
|
||||
self.gamma = Parameter(torch.zeros(1))
|
||||
self.softmax = Softmax(dim=-1)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
inputs :
|
||||
x : input feature maps( B X C X H X W)
|
||||
returns :
|
||||
out : attention value + input feature
|
||||
attention: B X C X C
|
||||
"""
|
||||
m_batchsize, C, height, width = x.size()
|
||||
proj_query = x.view(m_batchsize, C, -1)
|
||||
proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
|
||||
energy = torch.bmm(proj_query, proj_key)
|
||||
max_energy_0 = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)
|
||||
energy_new = max_energy_0 - energy
|
||||
attention = self.softmax(energy_new)
|
||||
proj_value = x.view(m_batchsize, C, -1)
|
||||
|
||||
out = torch.bmm(attention, proj_value)
|
||||
out = out.view(m_batchsize, C, height, width)
|
||||
|
||||
gamma = self.gamma.to(out.device)
|
||||
out = gamma * out + x
|
||||
return out
|
||||
|
||||
|
||||
# def get_attention_module_instance(
|
||||
# name: 'cam | pam | identity',
|
||||
# dim: int,
|
||||
# *,
|
||||
# out_dim=None,
|
||||
# use_head: bool=False,
|
||||
# dim_collapsion=2 # Used iff `used_head` set to True
|
||||
# ):
|
||||
#
|
||||
# name = name.lower()
|
||||
# assert name in ('cam', 'pam', 'identity')
|
||||
#
|
||||
# module_class = name_module_class_mapping[name]
|
||||
#
|
||||
# if out_dim is None:
|
||||
# out_dim = dim
|
||||
#
|
||||
# if use_head:
|
||||
# return DANetHead(
|
||||
# dim, out_dim,
|
||||
# nn.BatchNorm2d,
|
||||
# module_class,
|
||||
# dim_collapsion=dim_collapsion
|
||||
# )
|
||||
# else:
|
||||
# return module_class(dim)
|
|
@ -9,6 +9,11 @@ import torch.nn.functional as F
|
|||
from torch import nn
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, input):
|
||||
return input.view(input.size(0), -1)
|
||||
|
||||
|
||||
class GeneralizedMeanPooling(nn.Module):
|
||||
r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
|
||||
The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
|
||||
|
@ -50,13 +55,25 @@ class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
|
|||
|
||||
|
||||
class AdaptiveAvgMaxPool2d(nn.Module):
|
||||
def __init__(self, output_size):
|
||||
def __init__(self):
|
||||
super(AdaptiveAvgMaxPool2d, self).__init__()
|
||||
self.output_size = output_size
|
||||
self.avgpool = FastGlobalAvgPool2d()
|
||||
|
||||
def forward(self, x):
|
||||
x_max = F.adaptive_avg_pool2d(x, self.output_size)
|
||||
x_avg = F.adaptive_max_pool2d(x, self.output_size)
|
||||
x_avg = self.avgpool(x, self.output_size)
|
||||
x_max = F.adaptive_max_pool2d(x, 1)
|
||||
x = x_max + x_avg
|
||||
return x
|
||||
|
||||
|
||||
class FastGlobalAvgPool2d(nn.Module):
|
||||
def __init__(self, flatten=False):
|
||||
super(FastGlobalAvgPool2d, self).__init__()
|
||||
self.flatten = flatten
|
||||
|
||||
def forward(self, x):
|
||||
if self.flatten:
|
||||
in_size = x.size()
|
||||
return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
|
||||
else:
|
||||
return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
|
|
@ -11,7 +11,7 @@ from .build import REID_HEADS_REGISTRY
|
|||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
class BNneckHead(nn.Module):
|
||||
def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
|
||||
def __init__(self, cfg, in_feat, num_classes, pool_layer):
|
||||
super().__init__()
|
||||
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
|
||||
self.pool_layer = pool_layer
|
||||
|
|
|
@ -11,7 +11,7 @@ from fastreid.utils.weight_init import weights_init_classifier
|
|||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
class LinearHead(nn.Module):
|
||||
def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
|
||||
def __init__(self, cfg, in_feat, num_classes, pool_layer):
|
||||
super().__init__()
|
||||
self.pool_layer = pool_layer
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ from .build import REID_HEADS_REGISTRY
|
|||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
class ReductionHead(nn.Module):
|
||||
def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
|
||||
def __init__(self, cfg, in_feat, num_classes, pool_layer):
|
||||
super().__init__()
|
||||
reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM
|
||||
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
|
||||
|
|
|
@ -167,10 +167,10 @@ class CircleLoss(object):
|
|||
self._scale = cfg.MODEL.LOSSES.CIRCLE.SCALE
|
||||
|
||||
self.m = cfg.MODEL.LOSSES.CIRCLE.MARGIN
|
||||
self.s = cfg.MODEL.LOSSES.CIRCLE.SCALE
|
||||
self.s = cfg.MODEL.LOSSES.CIRCLE.ALPHA
|
||||
|
||||
def __call__(self, _, global_features, targets):
|
||||
global_features = normalize(global_features, axis=-1)
|
||||
global_features = F.normalize(global_features, dim=1)
|
||||
|
||||
sim_mat = torch.matmul(global_features, global_features.t())
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from fastreid.layers import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d
|
||||
from fastreid.layers import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d, FastGlobalAvgPool2d
|
||||
from fastreid.modeling.backbones import build_backbone
|
||||
from fastreid.modeling.heads import build_reid_heads
|
||||
from fastreid.modeling.losses import reid_losses
|
||||
|
@ -26,10 +26,10 @@ class Baseline(nn.Module):
|
|||
|
||||
# head
|
||||
pool_type = cfg.MODEL.HEADS.POOL_LAYER
|
||||
if pool_type == 'avgpool': pool_layer = nn.AdaptiveAvgPool2d(1)
|
||||
if pool_type == 'avgpool': pool_layer = FastGlobalAvgPool2d()
|
||||
elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1)
|
||||
elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP()
|
||||
elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d(1)
|
||||
elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d()
|
||||
elif pool_type == "identity": pool_layer = nn.Identity()
|
||||
else:
|
||||
raise KeyError(f"{pool_type} is invalid, please choose from "
|
||||
|
|
|
@ -8,7 +8,7 @@ import copy
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from fastreid.layers import GeneralizedMeanPoolingP, get_norm, AdaptiveAvgMaxPool2d
|
||||
from fastreid.layers import GeneralizedMeanPoolingP, get_norm, AdaptiveAvgMaxPool2d, FastGlobalAvgPool2d
|
||||
from fastreid.modeling.backbones import build_backbone
|
||||
from fastreid.modeling.backbones.resnet import Bottleneck
|
||||
from fastreid.modeling.heads import build_reid_heads
|
||||
|
@ -51,10 +51,10 @@ class MGN(nn.Module):
|
|||
res_p_conv5.load_state_dict(backbone.layer4.state_dict())
|
||||
|
||||
pool_type = cfg.MODEL.HEADS.POOL_LAYER
|
||||
if pool_type == 'avgpool': pool_layer = nn.AdaptiveAvgPool2d(1)
|
||||
if pool_type == 'avgpool': pool_layer = FastGlobalAvgPool2d()
|
||||
elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1)
|
||||
elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP()
|
||||
elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d(1)
|
||||
elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d()
|
||||
elif pool_type == "identity": pool_layer = nn.Identity()
|
||||
else:
|
||||
raise KeyError(f"{pool_type} is invalid, please choose from "
|
||||
|
|
Loading…
Reference in New Issue