[Enhance] Add stochastic depth decay rule in resnet. (#1363)

* add stochastic depth decay rule to drop path rate

* add default value

* update

* pass ut

* update

* pass ut

* remove np
Yixiao Fang 2023-02-22 11:04:28 +08:00 committed by GitHub
parent 8352951f3d
commit ab953f3209
2 changed files with 36 additions and 3 deletions
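
For context: the "stochastic depth decay rule" here means the drop-path probability grows linearly with block depth, so the first block of the network is almost never dropped while the last block is dropped at the full drop_path_rate. A minimal sketch of the schedule the diff below builds (the stage layout and the 0.1 rate are illustrative values, not taken from the commit):

import torch

# Hypothetical ResNet-50-style layout: stages of (3, 4, 6, 3) blocks.
stage_blocks = (3, 4, 6, 3)
drop_path_rate = 0.1

# Ramp the rate linearly from 0 at the first block of the network
# to `drop_path_rate` at the very last block.
total_depth = sum(stage_blocks)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]

# Each stage consumes its own slice of the schedule, mirroring the
# `dpr[:num_blocks]` / `dpr = dpr[num_blocks:]` pattern in the diff below.
for num_blocks in stage_blocks:
    stage_rates, dpr = dpr[:num_blocks], dpr[num_blocks:]
    print([round(r, 3) for r in stage_rates])

Because each slice continues where the previous stage left off, the rates keep increasing across stage boundaries instead of restarting at zero in each layer.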

mmcls/models/backbones/res2net.py

@@ -143,6 +143,8 @@ class Res2Layer(Sequential):
             Default: dict(type='BN')
         scales (int): Scales used in Res2Net. Default: 4
         base_width (int): Basic width of each scale. Default: 26
+        drop_path_rate (float or np.ndarray): stochastic depth rate.
+            Default: 0.
     """

     def __init__(self,
@@ -156,9 +158,16 @@ class Res2Layer(Sequential):
                  norm_cfg=dict(type='BN'),
                  scales=4,
                  base_width=26,
+                 drop_path_rate=0.0,
                  **kwargs):
         self.block = block

+        if isinstance(drop_path_rate, float):
+            drop_path_rate = [drop_path_rate] * num_blocks
+
+        assert len(drop_path_rate
+                   ) == num_blocks, 'Please check the length of drop_path_rate'
+
         downsample = None
         if stride != 1 or in_channels != out_channels:
             if avg_down:
@@ -201,9 +210,10 @@ class Res2Layer(Sequential):
                 scales=scales,
                 base_width=base_width,
                 stage_type='stage',
+                drop_path_rate=drop_path_rate[0],
                 **kwargs))
         in_channels = out_channels
-        for _ in range(1, num_blocks):
+        for i in range(1, num_blocks):
             layers.append(
                 block(
                     in_channels=in_channels,
@@ -213,6 +223,7 @@ class Res2Layer(Sequential):
                     norm_cfg=norm_cfg,
                     scales=scales,
                     base_width=base_width,
+                    drop_path_rate=drop_path_rate[i],
                     **kwargs))

         super(Res2Layer, self).__init__(*layers)
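
Both Res2Layer above and ResLayer below normalize the argument the same way: a bare float is broadcast to every block, while a sequence must supply exactly one rate per block. A standalone sketch of that normalization (the helper name is hypothetical; the commit inlines these lines in each __init__):

def normalize_drop_path_rate(drop_path_rate, num_blocks):
    # A scalar means every block in the layer shares the same rate.
    if isinstance(drop_path_rate, float):
        drop_path_rate = [drop_path_rate] * num_blocks
    # A sequence must match the number of blocks one-to-one.
    assert len(drop_path_rate) == num_blocks, \
        'Please check the length of drop_path_rate'
    return drop_path_rate

assert normalize_drop_path_rate(0.1, 3) == [0.1, 0.1, 0.1]
assert normalize_drop_path_rate([0.0, 0.05, 0.1], 3) == [0.0, 0.05, 0.1]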

mmcls/models/backbones/resnet.py

@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
@@ -334,6 +334,8 @@ class ResLayer(nn.Sequential):
             layer. Default: None
         norm_cfg (dict): dictionary to construct and config norm layer.
             Default: dict(type='BN')
+        drop_path_rate (float or list): stochastic depth rate.
+            Default: 0.
     """

     def __init__(self,
@@ -346,10 +348,17 @@ class ResLayer(nn.Sequential):
                  avg_down=False,
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
+                 drop_path_rate=0.0,
                  **kwargs):
         self.block = block
         self.expansion = get_expansion(block, expansion)

+        if isinstance(drop_path_rate, float):
+            drop_path_rate = [drop_path_rate] * num_blocks
+
+        assert len(drop_path_rate
+                   ) == num_blocks, 'Please check the length of drop_path_rate'
+
         downsample = None
         if stride != 1 or in_channels != out_channels:
             downsample = []
@@ -384,6 +393,7 @@ class ResLayer(nn.Sequential):
                 downsample=downsample,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
+                drop_path_rate=drop_path_rate[0],
                 **kwargs))
         in_channels = out_channels
         for i in range(1, num_blocks):
@@ -395,6 +405,7 @@ class ResLayer(nn.Sequential):
                     stride=1,
                     conv_cfg=conv_cfg,
                     norm_cfg=norm_cfg,
+                    drop_path_rate=drop_path_rate[i],
                     **kwargs))

         super(ResLayer, self).__init__(*layers)
@@ -518,6 +529,16 @@ class ResNet(BaseBackbone):
         self.res_layers = []
         _in_channels = stem_channels
         _out_channels = base_channels * self.expansion
+
+        # stochastic depth decay rule
+        total_depth = sum(stage_blocks)
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+        ]
+        # net_num_blocks = sum(stage_blocks)
+        # dpr = np.linspace(0, drop_path_rate, net_num_blocks)
+        # block_id = 0
+
         for i, num_blocks in enumerate(self.stage_blocks):
             stride = strides[i]
             dilation = dilations[i]
@@ -534,9 +555,10 @@ class ResNet(BaseBackbone):
                 with_cp=with_cp,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
-                drop_path_rate=drop_path_rate)
+                drop_path_rate=dpr[:num_blocks])
             _in_channels = _out_channels
             _out_channels *= 2
+            dpr = dpr[num_blocks:]
             layer_name = f'layer{i + 1}'
             self.add_module(layer_name, res_layer)
             self.res_layers.append(layer_name)
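
With the decay rule in place, the drop_path_rate passed to the backbone acts as the maximum of a linear schedule rather than a uniform per-block rate. A hedged usage sketch, assuming the mmcls 1.x import layout this repository used at the time:

import torch

from mmcls.models.backbones import ResNet

# depth=50 selects stage_blocks (3, 4, 6, 3); the full 0.1 rate is reached
# only by the last block of the last stage under the linspace schedule.
model = ResNet(depth=50, drop_path_rate=0.1)
model.train()  # DropPath only drops residual branches in training mode

feats = model(torch.randn(1, 3, 224, 224))  # tuple of stage feature maps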