update fast global avgpool

Summary: update fast pool according to https://arxiv.org/pdf/2003.13630.pdf
pull/150/head
liaoxingyu 2020-06-12 16:34:03 +08:00
parent cbdc01a1c3
commit 56a1ab4a5d
10 changed files with 42 additions and 208 deletions

View File

@ -86,7 +86,7 @@ _C.MODEL.LOSSES.CE = CN()
# if epsilon == 0, it means no label smooth regularization,
# if epsilon == -1, it means adaptive label smooth regularization
_C.MODEL.LOSSES.CE.EPSILON = 0.0
_C.MODEL.LOSSES.CE.ALPHA = 0.3
_C.MODEL.LOSSES.CE.ALPHA = 0.2
_C.MODEL.LOSSES.CE.SCALE = 1.0
# Triplet Loss options
@ -100,7 +100,8 @@ _C.MODEL.LOSSES.TRI.SCALE = 1.0
# Circle Loss options
_C.MODEL.LOSSES.CIRCLE = CN()
_C.MODEL.LOSSES.CIRCLE.MARGIN = 0.25
_C.MODEL.LOSSES.CIRCLE.SCALE = 128
_C.MODEL.LOSSES.CIRCLE.ALPHA = 128
_C.MODEL.LOSSES.CIRCLE.SCALE = 1.0
# Focal Loss options
_C.MODEL.LOSSES.FL = CN()

View File

@ -3,22 +3,15 @@
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn
from .batch_drop import BatchDrop
from .attention import *
from .batch_norm import *
from .context_block import ContextBlock
from .non_local import Non_local
from .se_layer import SELayer
from .frn import FRN, TLU
from .activation import *
from .gem_pool import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d
from .arcface import Arcface
from .batch_drop import BatchDrop
from .batch_norm import *
from .circle import Circle
from .context_block import ContextBlock
from .frn import FRN, TLU
from .non_local import Non_local
from .pooling import *
from .se_layer import SELayer
from .splat import SplAtConv2d
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)

View File

@ -1,177 +0,0 @@
# encoding: utf-8
"""
@author: CASIA IVA
@contact: jliu@nlpr.ia.ac.cn
"""
import torch
from torch.nn import Module, Conv2d, Parameter, Softmax
import torch.nn as nn
__all__ = ['PAM_Module', 'CAM_Module', 'DANetHead',]
class DANetHead(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
norm_layer: nn.Module,
module_class: type,
dim_collapsion: int=2):
super(DANetHead, self).__init__()
inter_channels = in_channels // dim_collapsion
self.conv5c = nn.Sequential(
nn.Conv2d(
in_channels,
inter_channels,
3,
padding=1,
bias=False
),
norm_layer(inter_channels),
nn.ReLU()
)
self.attention_module = module_class(inter_channels)
self.conv52 = nn.Sequential(
nn.Conv2d(
inter_channels,
inter_channels,
3,
padding=1,
bias=False
),
norm_layer(inter_channels),
nn.ReLU()
)
self.conv7 = nn.Sequential(
nn.Dropout2d(0.1, False),
nn.Conv2d(inter_channels, out_channels, 1)
)
def forward(self, x):
feat2 = self.conv5c(x)
sc_feat = self.attention_module(feat2)
sc_conv = self.conv52(sc_feat)
sc_output = self.conv7(sc_conv)
return sc_output
class PAM_Module(nn.Module):
""" Position attention module"""
# Ref from SAGAN
def __init__(self, in_dim):
super(PAM_Module, self).__init__()
self.channel_in = in_dim
self.query_conv = Conv2d(
in_channels=in_dim,
out_channels=in_dim // 8,
kernel_size=1
)
self.key_conv = Conv2d(
in_channels=in_dim,
out_channels=in_dim // 8,
kernel_size=1
)
self.value_conv = Conv2d(
in_channels=in_dim,
out_channels=in_dim,
kernel_size=1
)
self.gamma = Parameter(torch.zeros(1))
self.softmax = Softmax(dim=-1)
def forward(self, x):
"""
inputs :
x : input feature maps( B X C X H X W)
returns :
out : attention value + input feature
attention: B X (HxW) X (HxW)
"""
m_batchsize, C, height, width = x.size()
proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
energy = torch.bmm(proj_query, proj_key)
attention = self.softmax(energy)
proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
out = torch.bmm(
proj_value,
attention.permute(0, 2, 1)
)
attention_mask = out.view(m_batchsize, C, height, width)
out = self.gamma * attention_mask + x
return out
class CAM_Module(nn.Module):
""" Channel attention module"""
def __init__(self, in_dim):
super().__init__()
self.channel_in = in_dim
self.gamma = Parameter(torch.zeros(1))
self.softmax = Softmax(dim=-1)
def forward(self, x):
"""
inputs :
x : input feature maps( B X C X H X W)
returns :
out : attention value + input feature
attention: B X C X C
"""
m_batchsize, C, height, width = x.size()
proj_query = x.view(m_batchsize, C, -1)
proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
energy = torch.bmm(proj_query, proj_key)
max_energy_0 = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)
energy_new = max_energy_0 - energy
attention = self.softmax(energy_new)
proj_value = x.view(m_batchsize, C, -1)
out = torch.bmm(attention, proj_value)
out = out.view(m_batchsize, C, height, width)
gamma = self.gamma.to(out.device)
out = gamma * out + x
return out
# def get_attention_module_instance(
# name: 'cam | pam | identity',
# dim: int,
# *,
# out_dim=None,
# use_head: bool=False,
# dim_collapsion=2 # Used iff `used_head` set to True
# ):
#
# name = name.lower()
# assert name in ('cam', 'pam', 'identity')
#
# module_class = name_module_class_mapping[name]
#
# if out_dim is None:
# out_dim = dim
#
# if use_head:
# return DANetHead(
# dim, out_dim,
# nn.BatchNorm2d,
# module_class,
# dim_collapsion=dim_collapsion
# )
# else:
# return module_class(dim)

View File

@ -9,6 +9,11 @@ import torch.nn.functional as F
from torch import nn
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
class GeneralizedMeanPooling(nn.Module):
r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
@ -50,13 +55,25 @@ class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
class AdaptiveAvgMaxPool2d(nn.Module):
def __init__(self, output_size):
def __init__(self):
super(AdaptiveAvgMaxPool2d, self).__init__()
self.output_size = output_size
self.avgpool = FastGlobalAvgPool2d()
def forward(self, x):
x_max = F.adaptive_avg_pool2d(x, self.output_size)
x_avg = F.adaptive_max_pool2d(x, self.output_size)
x_avg = self.avgpool(x, self.output_size)
x_max = F.adaptive_max_pool2d(x, 1)
x = x_max + x_avg
return x
class FastGlobalAvgPool2d(nn.Module):
def __init__(self, flatten=False):
super(FastGlobalAvgPool2d, self).__init__()
self.flatten = flatten
def forward(self, x):
if self.flatten:
in_size = x.size()
return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
else:
return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)

View File

@ -11,7 +11,7 @@ from .build import REID_HEADS_REGISTRY
@REID_HEADS_REGISTRY.register()
class BNneckHead(nn.Module):
def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
def __init__(self, cfg, in_feat, num_classes, pool_layer):
super().__init__()
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
self.pool_layer = pool_layer

View File

@ -11,7 +11,7 @@ from fastreid.utils.weight_init import weights_init_classifier
@REID_HEADS_REGISTRY.register()
class LinearHead(nn.Module):
def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
def __init__(self, cfg, in_feat, num_classes, pool_layer):
super().__init__()
self.pool_layer = pool_layer

View File

@ -11,7 +11,7 @@ from .build import REID_HEADS_REGISTRY
@REID_HEADS_REGISTRY.register()
class ReductionHead(nn.Module):
def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
def __init__(self, cfg, in_feat, num_classes, pool_layer):
super().__init__()
reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT

View File

@ -167,10 +167,10 @@ class CircleLoss(object):
self._scale = cfg.MODEL.LOSSES.CIRCLE.SCALE
self.m = cfg.MODEL.LOSSES.CIRCLE.MARGIN
self.s = cfg.MODEL.LOSSES.CIRCLE.SCALE
self.s = cfg.MODEL.LOSSES.CIRCLE.ALPHA
def __call__(self, _, global_features, targets):
global_features = normalize(global_features, axis=-1)
global_features = F.normalize(global_features, dim=1)
sim_mat = torch.matmul(global_features, global_features.t())

View File

@ -7,7 +7,7 @@
import torch
from torch import nn
from fastreid.layers import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d
from fastreid.layers import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d, FastGlobalAvgPool2d
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.heads import build_reid_heads
from fastreid.modeling.losses import reid_losses
@ -26,10 +26,10 @@ class Baseline(nn.Module):
# head
pool_type = cfg.MODEL.HEADS.POOL_LAYER
if pool_type == 'avgpool': pool_layer = nn.AdaptiveAvgPool2d(1)
if pool_type == 'avgpool': pool_layer = FastGlobalAvgPool2d()
elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP()
elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d(1)
elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d()
elif pool_type == "identity": pool_layer = nn.Identity()
else:
raise KeyError(f"{pool_type} is invalid, please choose from "

View File

@ -8,7 +8,7 @@ import copy
import torch
from torch import nn
from fastreid.layers import GeneralizedMeanPoolingP, get_norm, AdaptiveAvgMaxPool2d
from fastreid.layers import GeneralizedMeanPoolingP, get_norm, AdaptiveAvgMaxPool2d, FastGlobalAvgPool2d
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.backbones.resnet import Bottleneck
from fastreid.modeling.heads import build_reid_heads
@ -51,10 +51,10 @@ class MGN(nn.Module):
res_p_conv5.load_state_dict(backbone.layer4.state_dict())
pool_type = cfg.MODEL.HEADS.POOL_LAYER
if pool_type == 'avgpool': pool_layer = nn.AdaptiveAvgPool2d(1)
if pool_type == 'avgpool': pool_layer = FastGlobalAvgPool2d()
elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP()
elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d(1)
elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d()
elif pool_type == "identity": pool_layer = nn.Identity()
else:
raise KeyError(f"{pool_type} is invalid, please choose from "