51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
from mmcv.cnn import NonLocal2d
|
|
|
|
from ..builder import HEADS
|
|
from .fcn_head import FCNHead
|
|
|
|
|
|
@HEADS.register_module()
class NLHead(FCNHead):
    """Non-local Neural Networks.

    This head is the implementation of `NLNet
    <https://arxiv.org/abs/1711.07971>`_.

    It reuses ``FCNHead`` with exactly two convolutions and inserts a
    ``NonLocal2d`` block between ``convs[0]`` and ``convs[1]``.

    Args:
        reduction (int): Reduction factor of projection transform. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode forwarded to ``NonLocal2d``, e.g.
            'embedded_gaussian' or 'dot_product'.
            Default: 'embedded_gaussian'.
        **kwargs: Forwarded to ``FCNHead``. ``num_convs`` is fixed to 2 by
            this head; passing it here raises a duplicate-keyword
            ``TypeError``.
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 **kwargs):
        # forward() indexes convs[0] and convs[1] explicitly, so the number
        # of FCN convolutions is pinned to 2.
        super().__init__(num_convs=2, **kwargs)
        self.reduction = reduction
        self.use_scale = use_scale
        self.mode = mode
        # The non-local block consumes the output of convs[0]; it is built
        # with in_channels=self.channels, so convs[0] is presumably expected
        # to emit self.channels channels (set up by FCNHead -- confirm).
        self.nl_block = NonLocal2d(
            in_channels=self.channels,
            reduction=self.reduction,
            use_scale=self.use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=self.mode)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Features handed to ``self._transform_inputs``
                (presumably the multi-level backbone outputs -- confirm
                against ``FCNHead``/its base class).

        Returns:
            The result of ``self.cls_seg`` applied to the decoded features.
        """
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        # Non-local attention sits between the two FCN convolutions.
        output = self.nl_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            # Concatenate the transformed input with the decoded features on
            # dim=1 (the channel axis, assuming NCHW tensors) before fusing.
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
|