mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add DropPath (stochastic depth) to RegNet
This commit is contained in:
parent
47794d2c59
commit
6890300877
@ -18,7 +18,7 @@ import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule
|
||||
from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule, DropPath
|
||||
from .registry import register_model
|
||||
|
||||
|
||||
@ -195,7 +195,7 @@ class RegStage(nn.Module):
|
||||
"""Stage (sequence of blocks w/ the same output shape)."""
|
||||
|
||||
def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width,
|
||||
block_fn=Bottleneck, se_ratio=0.):
|
||||
block_fn=Bottleneck, se_ratio=0., drop_path_rate=None, drop_block=None):
|
||||
super(RegStage, self).__init__()
|
||||
block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args
|
||||
first_dilation = 1 if dilation in (1, 2) else 2
|
||||
@ -203,6 +203,7 @@ class RegStage(nn.Module):
|
||||
block_stride = stride if i == 0 else 1
|
||||
block_in_chs = in_chs if i == 0 else out_chs
|
||||
block_dilation = first_dilation if i == 0 else dilation
|
||||
drop_path = DropPath(drop_path_rate[i]) if drop_path_rate is not None else None
|
||||
if (block_in_chs != out_chs) or (block_stride != 1):
|
||||
proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation)
|
||||
else:
|
||||
@ -212,7 +213,7 @@ class RegStage(nn.Module):
|
||||
self.add_module(
|
||||
name, block_fn(
|
||||
block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio,
|
||||
downsample=proj_block, **block_kwargs)
|
||||
downsample=proj_block, drop_block=drop_block, drop_path=drop_path, **block_kwargs)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -229,7 +230,7 @@ class RegNet(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
|
||||
zero_init_last_bn=True):
|
||||
drop_path_rate=0., zero_init_last_bn=True):
|
||||
super().__init__()
|
||||
# TODO add drop block, drop path, anti-aliasing, custom bn/act args
|
||||
self.num_classes = num_classes
|
||||
@ -244,7 +245,7 @@ class RegNet(nn.Module):
|
||||
# Construct the stages
|
||||
prev_width = stem_width
|
||||
curr_stride = 2
|
||||
stage_params = self._get_stage_params(cfg, output_stride=output_stride)
|
||||
stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
|
||||
se_ratio = cfg['se_ratio']
|
||||
for i, stage_args in enumerate(stage_params):
|
||||
stage_name = "s{}".format(i + 1)
|
||||
@ -272,7 +273,7 @@ class RegNet(nn.Module):
|
||||
if hasattr(m, 'zero_init_last_bn'):
|
||||
m.zero_init_last_bn()
|
||||
|
||||
def _get_stage_params(self, cfg, default_stride=2, output_stride=32):
|
||||
def _get_stage_params(self, cfg, default_stride=2, output_stride=32, drop_path_rate=0.):
|
||||
# Generate RegNet ws per block
|
||||
w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth']
|
||||
widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
|
||||
@ -285,24 +286,26 @@ class RegNet(nn.Module):
|
||||
stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)]
|
||||
stage_strides = []
|
||||
stage_dilations = []
|
||||
total_stride = 2
|
||||
net_stride = 2
|
||||
dilation = 1
|
||||
for _ in range(num_stages):
|
||||
if total_stride >= output_stride:
|
||||
if net_stride >= output_stride:
|
||||
dilation *= default_stride
|
||||
stride = 1
|
||||
else:
|
||||
stride = default_stride
|
||||
total_stride *= stride
|
||||
net_stride *= stride
|
||||
stage_strides.append(stride)
|
||||
stage_dilations.append(dilation)
|
||||
stage_dpr = np.split(np.linspace(0, drop_path_rate, d), np.cumsum(stage_depths[:-1]))
|
||||
|
||||
# Adjust the compatibility of ws and gws
|
||||
stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
|
||||
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width']
|
||||
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rate']
|
||||
stage_params = [
|
||||
dict(zip(param_names, params)) for params in
|
||||
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups)]
|
||||
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
|
||||
stage_dpr)]
|
||||
return stage_params
|
||||
|
||||
def get_classifier(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user