[Enhance] Add stochastic depth decay rule in resnet. (#1363)
* add stochastic depth decay rule to drop path rate
* add default value
* update
* pass ut
* update
* pass ut
* remove np
parent 8352951f3d
commit ab953f3209
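The change implements the commonly used stochastic depth decay rule: instead of applying one uniform `drop_path_rate` to every residual block, the per-block rate ramps linearly from 0 to `drop_path_rate` over the total block depth, and each stage consumes its own slice. A minimal sketch of the rule; the `stage_blocks = [3, 4, 6, 3]` layout is an illustrative ResNet-50 assumption, not taken from this diff:

```python
import torch

# Illustrative values (not from this diff): ResNet-50 has 3, 4, 6 and 3
# residual blocks in its four stages.
drop_path_rate = 0.1
stage_blocks = [3, 4, 6, 3]

# Decay rule: per-block rates grow linearly from 0 to drop_path_rate
# over the whole network depth.
total_depth = sum(stage_blocks)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]

# Each stage consumes the next `num_blocks` rates, mirroring the loop
# added to ResNet.__init__ below.
for stage, num_blocks in enumerate(stage_blocks, start=1):
    stage_rates, dpr = dpr[:num_blocks], dpr[num_blocks:]
    print(f'layer{stage}: {[round(r, 3) for r in stage_rates]}')
```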
@@ -143,6 +143,8 @@ class Res2Layer(Sequential):
             Default: dict(type='BN')
         scales (int): Scales used in Res2Net. Default: 4
         base_width (int): Basic width of each scale. Default: 26
+        drop_path_rate (float or np.ndarray): stochastic depth rate.
+            Default: 0.
     """
 
     def __init__(self,
@@ -156,9 +158,16 @@ class Res2Layer(Sequential):
                  norm_cfg=dict(type='BN'),
                  scales=4,
                  base_width=26,
+                 drop_path_rate=0.0,
                  **kwargs):
         self.block = block
 
+        if isinstance(drop_path_rate, float):
+            drop_path_rate = [drop_path_rate] * num_blocks
+
+        assert len(drop_path_rate
+                   ) == num_blocks, 'Please check the length of drop_path_rate'
+
         downsample = None
         if stride != 1 or in_channels != out_channels:
             if avg_down:
@@ -201,9 +210,10 @@ class Res2Layer(Sequential):
                 scales=scales,
                 base_width=base_width,
                 stage_type='stage',
+                drop_path_rate=drop_path_rate[0],
                 **kwargs))
         in_channels = out_channels
-        for _ in range(1, num_blocks):
+        for i in range(1, num_blocks):
             layers.append(
                 block(
                     in_channels=in_channels,
@@ -213,6 +223,7 @@ class Res2Layer(Sequential):
                     norm_cfg=norm_cfg,
                     scales=scales,
                     base_width=base_width,
+                    drop_path_rate=drop_path_rate[i],
                     **kwargs))
         super(Res2Layer, self).__init__(*layers)
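Both `Res2Layer` above and `ResLayer` in the next file normalize the argument the same way: a scalar is broadcast to one rate per block, while a sequence must supply exactly one rate per block. A standalone sketch of that handling (not the library code itself, just the same logic in isolation):

```python
def normalize_drop_path_rate(drop_path_rate, num_blocks):
    """Return one drop path rate per block in a stage."""
    # A scalar means every block in the stage shares the same rate.
    if isinstance(drop_path_rate, float):
        drop_path_rate = [drop_path_rate] * num_blocks
    # A sequence must already provide exactly one rate per block.
    assert len(drop_path_rate) == num_blocks, \
        'Please check the length of drop_path_rate'
    return drop_path_rate


print(normalize_drop_path_rate(0.1, 3))               # [0.1, 0.1, 0.1]
print(normalize_drop_path_rate([0.0, 0.05, 0.1], 3))  # used as-is
```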
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
@@ -334,6 +334,8 @@ class ResLayer(nn.Sequential):
             layer. Default: None
         norm_cfg (dict): dictionary to construct and config norm layer.
             Default: dict(type='BN')
+        drop_path_rate (float or list): stochastic depth rate.
+            Default: 0.
     """
 
     def __init__(self,
@@ -346,10 +348,17 @@ class ResLayer(nn.Sequential):
                  avg_down=False,
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
+                 drop_path_rate=0.0,
                  **kwargs):
         self.block = block
         self.expansion = get_expansion(block, expansion)
 
+        if isinstance(drop_path_rate, float):
+            drop_path_rate = [drop_path_rate] * num_blocks
+
+        assert len(drop_path_rate
+                   ) == num_blocks, 'Please check the length of drop_path_rate'
+
         downsample = None
         if stride != 1 or in_channels != out_channels:
             downsample = []
@@ -384,6 +393,7 @@ class ResLayer(nn.Sequential):
                 downsample=downsample,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
+                drop_path_rate=drop_path_rate[0],
                 **kwargs))
         in_channels = out_channels
         for i in range(1, num_blocks):
@@ -395,6 +405,7 @@ class ResLayer(nn.Sequential):
                     stride=1,
                     conv_cfg=conv_cfg,
                     norm_cfg=norm_cfg,
+                    drop_path_rate=drop_path_rate[i],
                     **kwargs))
         super(ResLayer, self).__init__(*layers)
@@ -518,6 +529,16 @@ class ResNet(BaseBackbone):
         self.res_layers = []
         _in_channels = stem_channels
         _out_channels = base_channels * self.expansion
+
+        # stochastic depth decay rule
+        total_depth = sum(stage_blocks)
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+        ]
+        # net_num_blocks = sum(stage_blocks)
+        # dpr = np.linspace(0, drop_path_rate, net_num_blocks)
+        # block_id = 0
+
         for i, num_blocks in enumerate(self.stage_blocks):
             stride = strides[i]
             dilation = dilations[i]
@@ -534,9 +555,10 @@ class ResNet(BaseBackbone):
                 with_cp=with_cp,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
-                drop_path_rate=drop_path_rate)
+                drop_path_rate=dpr[:num_blocks])
             _in_channels = _out_channels
             _out_channels *= 2
+            dpr = dpr[num_blocks:]
             layer_name = f'layer{i + 1}'
             self.add_module(layer_name, res_layer)
             self.res_layers.append(layer_name)
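From the user's side the change is transparent: the backbone still takes a single scalar `drop_path_rate`, and the decay rule spreads it over all blocks so deeper blocks are dropped more often. A hedged usage sketch, assuming an mmcls build that includes this change:

```python
import torch
from mmcls.models.backbones import ResNet

# A single scalar; the decay rule distributes per-block rates from 0 up
# to 0.1 across the 16 residual blocks of ResNet-50.
model = ResNet(depth=50, drop_path_rate=0.1)
model.eval()

with torch.no_grad():
    feats = model(torch.rand(1, 3, 224, 224))
print([f.shape for f in feats])
```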