mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] The interface multiscale_output is defined but not used (#830)
* Add interface multiscale_output * Add space between args and their types * Fix default value
This commit is contained in:
parent
4ca42a3ee4
commit
aa438f5c95
@ -218,26 +218,41 @@ class HRModule(BaseModule):
|
|||||||
class HRNet(BaseModule):
|
class HRNet(BaseModule):
|
||||||
"""HRNet backbone.
|
"""HRNet backbone.
|
||||||
|
|
||||||
High-Resolution Representations for Labeling Pixels and Regions
|
`High-Resolution Representations for Labeling Pixels and Regions
|
||||||
arXiv: https://arxiv.org/abs/1904.04514
|
arXiv: <https://arxiv.org/abs/1904.04514>`_.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
extra (dict): detailed configuration for each stage of HRNet.
|
extra (dict): Detailed configuration for each stage of HRNet.
|
||||||
|
There must be 4 stages, the configuration for each stage must have
|
||||||
|
5 keys:
|
||||||
|
|
||||||
|
- num_modules (int): The number of HRModule in this stage.
|
||||||
|
- num_branches (int): The number of branches in the HRModule.
|
||||||
|
- block (str): The type of convolution block.
|
||||||
|
- num_blocks (tuple): The number of blocks in each branch.
|
||||||
|
The length must be equal to num_branches.
|
||||||
|
- num_channels (tuple): The number of channels in each branch.
|
||||||
|
The length must be equal to num_branches.
|
||||||
in_channels (int): Number of input image channels. Normally 3.
|
in_channels (int): Number of input image channels. Normally 3.
|
||||||
conv_cfg (dict): dictionary to construct and config conv layer.
|
conv_cfg (dict): Dictionary to construct and config conv layer.
|
||||||
norm_cfg (dict): dictionary to construct and config norm layer.
|
Default: None.
|
||||||
|
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||||
|
Use `BN` by default.
|
||||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||||
and its variants only.
|
and its variants only. Default: False.
|
||||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||||
memory while slowing down the training speed.
|
memory while slowing down the training speed. Default: False.
|
||||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||||
-1 means not freezing any parameters. Default: -1.
|
-1 means not freezing any parameters. Default: -1.
|
||||||
zero_init_residual (bool): whether to use zero init for last norm layer
|
zero_init_residual (bool): Whether to use zero init for last norm layer
|
||||||
in resblocks to let them behave as identity.
|
in resblocks to let them behave as identity. Default: False.
|
||||||
pretrained (str, optional): model pretrained path. Default: None
|
multiscale_output (bool): Whether to output multi-level features
|
||||||
|
produced by multiple branches. If False, only the first level
|
||||||
|
feature will be output. Default: True.
|
||||||
|
pretrained (str, optional): Model pretrained path. Default: None.
|
||||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
Default: None
|
Default: None.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from mmseg.models import HRNet
|
>>> from mmseg.models import HRNet
|
||||||
@ -290,6 +305,7 @@ class HRNet(BaseModule):
|
|||||||
with_cp=False,
|
with_cp=False,
|
||||||
frozen_stages=-1,
|
frozen_stages=-1,
|
||||||
zero_init_residual=False,
|
zero_init_residual=False,
|
||||||
|
multiscale_output=True,
|
||||||
pretrained=None,
|
pretrained=None,
|
||||||
init_cfg=None):
|
init_cfg=None):
|
||||||
super(HRNet, self).__init__(init_cfg)
|
super(HRNet, self).__init__(init_cfg)
|
||||||
@ -299,7 +315,7 @@ class HRNet(BaseModule):
|
|||||||
assert not (init_cfg and pretrained), \
|
assert not (init_cfg and pretrained), \
|
||||||
'init_cfg and pretrained cannot be setting at the same time'
|
'init_cfg and pretrained cannot be setting at the same time'
|
||||||
if isinstance(pretrained, str):
|
if isinstance(pretrained, str):
|
||||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||||
'please use "init_cfg" instead')
|
'please use "init_cfg" instead')
|
||||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||||
elif pretrained is None:
|
elif pretrained is None:
|
||||||
@ -314,6 +330,16 @@ class HRNet(BaseModule):
|
|||||||
else:
|
else:
|
||||||
raise TypeError('pretrained must be a str or None')
|
raise TypeError('pretrained must be a str or None')
|
||||||
|
|
||||||
|
# Assert configurations of 4 stages are in extra
|
||||||
|
assert 'stage1' in extra and 'stage2' in extra \
|
||||||
|
and 'stage3' in extra and 'stage4' in extra
|
||||||
|
# Assert whether the length of `num_blocks` and `num_channels` are
|
||||||
|
# equal to `num_branches`
|
||||||
|
for i in range(4):
|
||||||
|
cfg = extra[f'stage{i + 1}']
|
||||||
|
assert len(cfg['num_blocks']) == cfg['num_branches'] and \
|
||||||
|
len(cfg['num_channels']) == cfg['num_branches']
|
||||||
|
|
||||||
self.extra = extra
|
self.extra = extra
|
||||||
self.conv_cfg = conv_cfg
|
self.conv_cfg = conv_cfg
|
||||||
self.norm_cfg = norm_cfg
|
self.norm_cfg = norm_cfg
|
||||||
@ -391,7 +417,7 @@ class HRNet(BaseModule):
|
|||||||
self.transition3 = self._make_transition_layer(pre_stage_channels,
|
self.transition3 = self._make_transition_layer(pre_stage_channels,
|
||||||
num_channels)
|
num_channels)
|
||||||
self.stage4, pre_stage_channels = self._make_stage(
|
self.stage4, pre_stage_channels = self._make_stage(
|
||||||
self.stage4_cfg, num_channels)
|
self.stage4_cfg, num_channels, multiscale_output=multiscale_output)
|
||||||
|
|
||||||
self._freeze_stages()
|
self._freeze_stages()
|
||||||
|
|
||||||
|
@ -1,12 +1,59 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
|
|
||||||
from mmseg.models.backbones import HRNet
|
from mmseg.models.backbones.hrnet import HRModule, HRNet
|
||||||
|
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('block', [BasicBlock, Bottleneck])
|
||||||
|
def test_hrmodule(block):
|
||||||
|
# Test multiscale forward
|
||||||
|
num_channles = (32, 64)
|
||||||
|
in_channels = [c * block.expansion for c in num_channles]
|
||||||
|
hrmodule = HRModule(
|
||||||
|
num_branches=2,
|
||||||
|
blocks=block,
|
||||||
|
in_channels=in_channels,
|
||||||
|
num_blocks=(4, 4),
|
||||||
|
num_channels=num_channles,
|
||||||
|
)
|
||||||
|
|
||||||
|
feats = [
|
||||||
|
torch.randn(1, in_channels[0], 64, 64),
|
||||||
|
torch.randn(1, in_channels[1], 32, 32)
|
||||||
|
]
|
||||||
|
feats = hrmodule(feats)
|
||||||
|
|
||||||
|
assert len(feats) == 2
|
||||||
|
assert feats[0].shape == torch.Size([1, in_channels[0], 64, 64])
|
||||||
|
assert feats[1].shape == torch.Size([1, in_channels[1], 32, 32])
|
||||||
|
|
||||||
|
# Test single scale forward
|
||||||
|
num_channles = (32, 64)
|
||||||
|
in_channels = [c * block.expansion for c in num_channles]
|
||||||
|
hrmodule = HRModule(
|
||||||
|
num_branches=2,
|
||||||
|
blocks=block,
|
||||||
|
in_channels=in_channels,
|
||||||
|
num_blocks=(4, 4),
|
||||||
|
num_channels=num_channles,
|
||||||
|
multiscale_output=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
feats = [
|
||||||
|
torch.randn(1, in_channels[0], 64, 64),
|
||||||
|
torch.randn(1, in_channels[1], 32, 32)
|
||||||
|
]
|
||||||
|
feats = hrmodule(feats)
|
||||||
|
|
||||||
|
assert len(feats) == 1
|
||||||
|
assert feats[0].shape == torch.Size([1, in_channels[0], 64, 64])
|
||||||
|
|
||||||
|
|
||||||
def test_hrnet_backbone():
|
def test_hrnet_backbone():
|
||||||
# Test HRNET with two stage frozen
|
# only have 3 stages
|
||||||
|
|
||||||
extra = dict(
|
extra = dict(
|
||||||
stage1=dict(
|
stage1=dict(
|
||||||
num_modules=1,
|
num_modules=1,
|
||||||
@ -25,13 +72,46 @@ def test_hrnet_backbone():
|
|||||||
num_branches=3,
|
num_branches=3,
|
||||||
block='BASIC',
|
block='BASIC',
|
||||||
num_blocks=(4, 4, 4),
|
num_blocks=(4, 4, 4),
|
||||||
num_channels=(32, 64, 128)),
|
num_channels=(32, 64, 128)))
|
||||||
stage4=dict(
|
|
||||||
num_modules=3,
|
with pytest.raises(AssertionError):
|
||||||
num_branches=4,
|
# HRNet now only support 4 stages
|
||||||
block='BASIC',
|
HRNet(extra=extra)
|
||||||
num_blocks=(4, 4, 4, 4),
|
extra['stage4'] = dict(
|
||||||
num_channels=(32, 64, 128, 256)))
|
num_modules=3,
|
||||||
|
num_branches=3, # should be 4
|
||||||
|
block='BASIC',
|
||||||
|
num_blocks=(4, 4, 4, 4),
|
||||||
|
num_channels=(32, 64, 128, 256))
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
# len(num_blocks) should equal num_branches
|
||||||
|
HRNet(extra=extra)
|
||||||
|
|
||||||
|
extra['stage4']['num_branches'] = 4
|
||||||
|
|
||||||
|
# Test hrnetv2p_w32
|
||||||
|
model = HRNet(extra=extra)
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 256, 256)
|
||||||
|
feats = model(imgs)
|
||||||
|
assert len(feats) == 4
|
||||||
|
assert feats[0].shape == torch.Size([1, 32, 64, 64])
|
||||||
|
assert feats[3].shape == torch.Size([1, 256, 8, 8])
|
||||||
|
|
||||||
|
# Test single scale output
|
||||||
|
model = HRNet(extra=extra, multiscale_output=False)
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 256, 256)
|
||||||
|
feats = model(imgs)
|
||||||
|
assert len(feats) == 1
|
||||||
|
assert feats[0].shape == torch.Size([1, 32, 64, 64])
|
||||||
|
|
||||||
|
# Test HRNET with two stage frozen
|
||||||
frozen_stages = 2
|
frozen_stages = 2
|
||||||
model = HRNet(extra, frozen_stages=frozen_stages)
|
model = HRNet(extra, frozen_stages=frozen_stages)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user