[Feature] Support kenerl updation for some decoder heads. (#1299)

* [Feature] Add kenerl updation for some decoder heads.

* [Feature] Add kenerl updation for some decoder heads.

* directly use forward_feature && modify other 3 decoder heads

* remover kernel_update attr

* delete unnecessary variables in forward function

* delete kernel update function

* delete kernel update function

* delete unnecessary docstrings

* modify comments in self._forward_feature()

* modify docstrings in self._forward_feature()

* fix docstring

* modify uperhead
This commit is contained in:
MengzhangLI 2022-02-27 11:35:29 +08:00 committed by GitHub
parent c0442f1b4a
commit 4912ea2b0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 68 additions and 13 deletions

View File

@ -91,8 +91,17 @@ class ASPPHead(BaseDecodeHead):
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
def _forward_feature(self, inputs):
"""Forward function for feature maps before classifying each pixel with
``self.cls_seg`` fc.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
x = self._transform_inputs(inputs)
aspp_outs = [
resize(
@ -103,6 +112,11 @@ class ASPPHead(BaseDecodeHead):
]
aspp_outs.extend(self.aspp_modules(x))
aspp_outs = torch.cat(aspp_outs, dim=1)
output = self.bottleneck(aspp_outs)
feats = self.bottleneck(aspp_outs)
return feats
def forward(self, inputs):
"""Forward function."""
output = self._forward_feature(inputs)
output = self.cls_seg(output)
return output

View File

@ -72,11 +72,25 @@ class FCNHead(BaseDecodeHead):
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def _forward_feature(self, inputs):
"""Forward function for feature maps before classifying each pixel with
``self.cls_seg`` fc.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
x = self._transform_inputs(inputs)
feats = self.convs(x)
if self.concat_input:
feats = self.conv_cat(torch.cat([x, feats], dim=1))
return feats
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
output = self.convs(x)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self._forward_feature(inputs)
output = self.cls_seg(output)
return output

View File

@ -92,12 +92,26 @@ class PSPHead(BaseDecodeHead):
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
def _forward_feature(self, inputs):
"""Forward function for feature maps before classifying each pixel with
``self.cls_seg`` fc.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
x = self._transform_inputs(inputs)
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
output = self.bottleneck(psp_outs)
feats = self.bottleneck(psp_outs)
return feats
def forward(self, inputs):
"""Forward function."""
output = self._forward_feature(inputs)
output = self.cls_seg(output)
return output

View File

@ -84,9 +84,17 @@ class UPerHead(BaseDecodeHead):
return output
def forward(self, inputs):
"""Forward function."""
def _forward_feature(self, inputs):
"""Forward function for feature maps before classifying each pixel with
``self.cls_seg`` fc.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
inputs = self._transform_inputs(inputs)
# build laterals
@ -122,6 +130,11 @@ class UPerHead(BaseDecodeHead):
mode='bilinear',
align_corners=self.align_corners)
fpn_outs = torch.cat(fpn_outs, dim=1)
output = self.fpn_bottleneck(fpn_outs)
feats = self.fpn_bottleneck(fpn_outs)
return feats
def forward(self, inputs):
"""Forward function."""
output = self._forward_feature(inputs)
output = self.cls_seg(output)
return output