# Copied from https://github.com/xvjiarui/GCNet/blob/master/mmdet/ops/gcb/context_block.py
"""Global context (GC) block, following the GCNet implementation linked above."""
import torch
from torch import nn

__all__ = ['ContextBlock']


def last_zero_init(m):
    """Set the weight (and bias, if present) of the last layer of ``m`` to zero.

    If ``m`` is an ``nn.Sequential`` its final sub-module is initialised,
    otherwise ``m`` itself is. With zero weight and bias the last layer
    outputs zeros at construction time.
    """
    target = m[-1] if isinstance(m, nn.Sequential) else m
    nn.init.constant_(target.weight, val=0)
    if hasattr(target, 'bias') and target.bias is not None:
        nn.init.constant_(target.bias, 0)


class ContextBlock(nn.Module):
    """Global context block.

    Pools one context vector per sample from the input feature map (softmax
    attention pooling or global average pooling), transforms it with a
    bottleneck (1x1 conv -> LayerNorm -> ReLU -> 1x1 conv) and fuses the
    result back into the input by broadcasting over the spatial dimensions.

    Args:
        inplanes: Number of input (and output) channels ``C``.
        ratio: Bottleneck ratio; the hidden width is ``int(inplanes * ratio)``.
        pooling_type: ``'att'`` for attention pooling over spatial positions,
            ``'avg'`` for global average pooling.
        fusion_types: Non-empty list/tuple drawn from ``'channel_add'`` (add
            the transformed context) and ``'channel_mul'`` (multiply by the
            sigmoid of the transformed context). When both are given the
            multiplication is applied first.

    ``forward`` takes a 4-D tensor ``[N, C, H, W]`` and returns one of the
    same shape. Contiguity of the input is not required.
    """

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_add', )):
        super().__init__()
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert pooling_type in ['avg', 'att'], \
            f"pooling_type must be 'avg' or 'att', got {pooling_type!r}"
        assert isinstance(fusion_types, (list, tuple)), \
            'fusion_types must be a list or tuple'
        assert all(f in valid_fusion_types for f in fusion_types), \
            f'fusion_types must be drawn from {valid_fusion_types}'
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        # Submodule creation order is kept identical to upstream: each layer
        # draws its default init from the global RNG at construction time.
        if pooling_type == 'att':
            # 1x1 conv producing one attention logit per spatial position.
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            # dim=2 is the flattened H*W axis of the [N, 1, H*W] mask.
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = self._make_transform()
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = self._make_transform()
        else:
            self.channel_mul_conv = None
        self.reset_parameters()

    def _make_transform(self):
        """Build the C -> planes -> C bottleneck applied to the context vector."""
        # Layer indices (0, 1, 3 carry parameters) determine state_dict keys.
        return nn.Sequential(
            nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(self.planes, self.inplanes, kernel_size=1))

    def reset_parameters(self):
        """Kaiming-init the attention conv and zero the last conv of each fusion branch."""
        if self.pooling_type == 'att':
            nn.init.kaiming_normal_(
                self.conv_mask.weight, a=0, mode='fan_in', nonlinearity='relu')
            if hasattr(self.conv_mask, 'bias') and self.conv_mask.bias is not None:
                nn.init.constant_(self.conv_mask.bias, 0)
            # Flag set as in upstream; presumably consulted by external init
            # code to skip re-initialising this conv -- confirm with callers.
            self.conv_mask.inited = True

        # Zeroed last layers make the add-branch contribute exactly 0 at init
        # (and the mul-branch scale by sigmoid(0) = 0.5).
        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        """Pool ``x`` of shape ``[N, C, H, W]`` to a context tensor ``[N, C, 1, 1]``."""
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            # reshape (not view): view raises on inputs whose H/W dims cannot
            # be merged without a copy, e.g. after transpose/permute.
            # [N, C, H * W] -> [N, 1, C, H * W]
            input_x = x.reshape(batch, channel, height * width).unsqueeze(1)
            # [N, 1, H, W] -> [N, 1, H * W]
            context_mask = self.conv_mask(x).reshape(batch, 1, height * width)
            # Softmax over spatial positions, then [N, 1, H * W, 1]
            context_mask = self.softmax(context_mask).unsqueeze(-1)
            # Attention-weighted sum over positions: [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.reshape(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)
        return context

    def forward(self, x):
        """Fuse the pooled global context into ``x``; the output has ``x``'s shape."""
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1], broadcast over H and W
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1], broadcast over H and W
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out