Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Fix ResNet based models to work w/ norm layers w/o affine params. Reformat long arg lists into vertical form.
parent d5aa17e415
commit 6902c48a5f
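What the fix addresses: a norm layer constructed without affine parameters (e.g. affine=False) exposes weight and bias as None, so the previously unconditional nn.init.zeros_(bn.weight) in the zero_init_last / init_weights hooks below raised an error. A minimal standalone sketch of the guard pattern adopted throughout this commit (illustration only, not code from the diff):

import torch.nn as nn

bn_affine = nn.BatchNorm2d(64)                   # weight/bias are nn.Parameter
bn_no_affine = nn.BatchNorm2d(64, affine=False)  # weight/bias are None

def zero_init_last(bn: nn.Module):
    # only zero the final norm's gamma when it actually exists
    if getattr(bn, 'weight', None) is not None:
        nn.init.zeros_(bn.weight)

zero_init_last(bn_affine)     # gamma zeroed, as before
zero_init_last(bn_no_affine)  # now a no-op instead of raising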
@@ -962,9 +962,21 @@ class BasicBlock(nn.Module):
     """
 
     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
-            downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
-            drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=1,
+            dilation=(1, 1),
+            group_size=None,
+            bottle_ratio=1.0,
+            downsample='avg',
+            attn_last=True,
+            linear_out=False,
+            layers: LayerFn = None,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(BasicBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
@@ -983,7 +995,7 @@ class BasicBlock(nn.Module):
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -1005,9 +1017,23 @@ class BottleneckBlock(nn.Module):
     """
 
     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
-            downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
-            layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1.,
+            group_size=None,
+            downsample='avg',
+            attn_last=False,
+            linear_out=False,
+            extra_conv=False,
+            bottle_in=False,
+            layers: LayerFn = None,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(BottleneckBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@@ -1031,7 +1057,7 @@ class BottleneckBlock(nn.Module):
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -1063,9 +1089,21 @@ class DarkBlock(nn.Module):
     """
 
     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-            downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
-            drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1.0,
+            group_size=None,
+            downsample='avg',
+            attn_last=True,
+            linear_out=False,
+            layers: LayerFn = None,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(DarkBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
@@ -1085,7 +1123,7 @@ class DarkBlock(nn.Module):
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv2_kxk.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -1114,9 +1152,21 @@ class EdgeBlock(nn.Module):
     """
 
     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-            downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
-            drop_block=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1.0,
+            group_size=None,
+            downsample='avg',
+            attn_last=False,
+            linear_out=False,
+            layers: LayerFn = None,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(EdgeBlock, self).__init__()
         layers = layers or LayerFn()
         mid_chs = make_divisible(out_chs * bottle_ratio)
@@ -1135,7 +1185,7 @@ class EdgeBlock(nn.Module):
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and getattr(self.conv2_1x1.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv2_1x1.bn.weight)
         for attn in (self.attn, self.attn_last):
             if hasattr(attn, 'reset_parameters'):
@@ -1162,8 +1212,19 @@ class RepVggBlock(nn.Module):
     """
 
     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
-            downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1.0,
+            group_size=None,
+            downsample='',
+            layers: LayerFn = None,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(RepVggBlock, self).__init__()
         layers = layers or LayerFn()
         groups = num_groups(group_size, in_chs)
@@ -1204,9 +1265,24 @@ class SelfAttnBlock(nn.Module):
     """
 
     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
-            downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
-            feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1.,
+            group_size=None,
+            downsample='avg',
+            extra_conv=False,
+            linear_out=False,
+            bottle_in=False,
+            post_attn_na=True,
+            feat_size=None,
+            layers: LayerFn = None,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(SelfAttnBlock, self).__init__()
         assert layers is not None
         mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@@ -1233,7 +1309,7 @@ class SelfAttnBlock(nn.Module):
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
 
     def init_weights(self, zero_init_last: bool = False):
-        if zero_init_last and self.shortcut is not None:
+        if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
         if hasattr(self.self_attn, 'reset_parameters'):
             self.self_attn.reset_parameters()
@@ -1274,8 +1350,17 @@ def create_block(block: Union[str, nn.Module], **kwargs):
 class Stem(nn.Sequential):
 
     def __init__(
-            self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
-            num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
+            self,
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=4,
+            pool='maxpool',
+            num_rep=3,
+            num_act=None,
+            chs_decay=0.5,
+            layers: LayerFn = None,
+    ):
         super().__init__()
         assert stride in (2, 4)
         layers = layers or LayerFn()
@@ -1319,7 +1404,14 @@ class Stem(nn.Sequential):
         assert curr_stride == stride
 
 
-def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None):
+def create_byob_stem(
+        in_chs,
+        out_chs,
+        stem_type='',
+        pool_type='',
+        feat_prefix='stem',
+        layers: LayerFn = None,
+):
     layers = layers or LayerFn()
     assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3')
     if 'quad' in stem_type:
@@ -1407,10 +1499,14 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo
 
 
 def create_byob_stages(
-        cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any],
+        cfg: ByoModelCfg,
+        drop_path_rate: float,
+        output_stride: int,
+        stem_feat: Dict[str, Any],
         feat_size: Optional[int] = None,
         layers: Optional[LayerFn] = None,
-        block_kwargs_fn: Optional[Callable] = update_block_kwargs):
+        block_kwargs_fn: Optional[Callable] = update_block_kwargs,
+):
 
     layers = layers or LayerFn()
     feature_info = []
@@ -1485,8 +1581,17 @@ class ByobNet(nn.Module):
     Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
     """
     def __init__(
-            self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
-            zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
+            self,
+            cfg: ByoModelCfg,
+            num_classes=1000,
+            in_chans=3,
+            global_pool='avg',
+            output_stride=32,
+            zero_init_last=True,
+            img_size=None,
+            drop_rate=0.,
+            drop_path_rate=0.,
+    ):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -51,9 +51,21 @@ class Bottle2neck(nn.Module):
     expansion = 4
 
     def __init__(
-            self, inplanes, planes, stride=1, downsample=None,
-            cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None,
-            act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            cardinality=1,
+            base_width=26,
+            scale=4,
+            dilation=1,
+            first_dilation=None,
+            act_layer=nn.ReLU,
+            norm_layer=None,
+            attn_layer=None,
+            **_,
+    ):
         super(Bottle2neck, self).__init__()
         self.scale = scale
         self.is_first = stride > 1 or downsample is not None
@@ -89,7 +101,8 @@ class Bottle2neck(nn.Module):
         self.downsample = downsample
 
     def zero_init_last(self):
-        nn.init.zeros_(self.bn3.weight)
+        if getattr(self.bn3, 'weight', None) is not None:
+            nn.init.zeros_(self.bn3.weight)
 
     def forward(self, x):
         shortcut = x
@@ -57,10 +57,27 @@ class ResNestBottleneck(nn.Module):
     expansion = 4
 
     def __init__(
-            self, inplanes, planes, stride=1, downsample=None,
-            radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False,
-            reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            radix=1,
+            cardinality=1,
+            base_width=64,
+            avd=False,
+            avd_first=False,
+            is_first=False,
+            reduce_first=1,
+            dilation=1,
+            first_dilation=None,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            attn_layer=None,
+            aa_layer=None,
+            drop_block=None,
+            drop_path=None,
+    ):
         super(ResNestBottleneck, self).__init__()
         assert reduce_first == 1  # not supported
         assert attn_layer is None  # not supported
@@ -103,7 +120,8 @@ class ResNestBottleneck(nn.Module):
         self.downsample = downsample
 
     def zero_init_last(self):
-        nn.init.zeros_(self.bn3.weight)
+        if getattr(self.bn3, 'weight', None) is not None:
+            nn.init.zeros_(self.bn3.weight)
 
     def forward(self, x):
         shortcut = x
@@ -337,9 +337,23 @@ class BasicBlock(nn.Module):
     expansion = 1
 
     def __init__(
-            self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
-            reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            cardinality=1,
+            base_width=64,
+            reduce_first=1,
+            dilation=1,
+            first_dilation=None,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            attn_layer=None,
+            aa_layer=None,
+            drop_block=None,
+            drop_path=None,
+    ):
         super(BasicBlock, self).__init__()
 
         assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@@ -370,7 +384,8 @@ class BasicBlock(nn.Module):
         self.drop_path = drop_path
 
     def zero_init_last(self):
-        nn.init.zeros_(self.bn2.weight)
+        if getattr(self.bn2, 'weight', None) is not None:
+            nn.init.zeros_(self.bn2.weight)
 
     def forward(self, x):
         shortcut = x
@@ -402,9 +417,23 @@ class Bottleneck(nn.Module):
     expansion = 4
 
     def __init__(
-            self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
-            reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            cardinality=1,
+            base_width=64,
+            reduce_first=1,
+            dilation=1,
+            first_dilation=None,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            attn_layer=None,
+            aa_layer=None,
+            drop_block=None,
+            drop_path=None,
+    ):
         super(Bottleneck, self).__init__()
 
         width = int(math.floor(planes * (base_width / 64)) * cardinality)
@@ -437,7 +466,8 @@ class Bottleneck(nn.Module):
         self.drop_path = drop_path
 
     def zero_init_last(self):
-        nn.init.zeros_(self.bn3.weight)
+        if getattr(self.bn3, 'weight', None) is not None:
+            nn.init.zeros_(self.bn3.weight)
 
     def forward(self, x):
         shortcut = x
@@ -508,8 +538,18 @@ def drop_blocks(drop_prob=0.):
 
 
 def make_blocks(
-        block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32,
-        down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs):
+        block_fn,
+        channels,
+        block_repeats,
+        inplanes,
+        reduce_first=1,
+        output_stride=32,
+        down_kernel_size=1,
+        avg_down=False,
+        drop_block_rate=0.,
+        drop_path_rate=0.,
+        **kwargs,
+):
     stages = []
     feature_info = []
     net_num_blocks = sum(block_repeats)
@@ -528,8 +568,14 @@ def make_blocks(
         downsample = None
         if stride != 1 or inplanes != planes * block_fn.expansion:
             down_kwargs = dict(
-                in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size,
-                stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer'))
+                in_channels=inplanes,
+                out_channels=planes * block_fn.expansion,
+                kernel_size=down_kernel_size,
+                stride=stride,
+                dilation=dilation,
+                first_dilation=prev_dilation,
+                norm_layer=kwargs.get('norm_layer'),
+            )
             downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
 
         block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
@@ -609,10 +655,30 @@ class ResNet(nn.Module):
     """
 
     def __init__(
-            self, block, layers, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
-            cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, block_reduce_first=1,
-            down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None,
-            drop_rate=0.0, drop_path_rate=0., drop_block_rate=0., zero_init_last=True, block_args=None):
+            self,
+            block,
+            layers,
+            num_classes=1000,
+            in_chans=3,
+            output_stride=32,
+            global_pool='avg',
+            cardinality=1,
+            base_width=64,
+            stem_width=64,
+            stem_type='',
+            replace_stem_pool=False,
+            block_reduce_first=1,
+            down_kernel_size=1,
+            avg_down=False,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            aa_layer=None,
+            drop_rate=0.0,
+            drop_path_rate=0.,
+            drop_block_rate=0.,
+            zero_init_last=True,
+            block_args=None,
+    ):
         super(ResNet, self).__init__()
         block_args = block_args or dict()
         assert output_stride in (8, 16, 32)
@@ -663,10 +729,23 @@ class ResNet(nn.Module):
         # Feature Blocks
         channels = [64, 128, 256, 512]
         stage_modules, stage_feature_info = make_blocks(
-            block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
-            output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
-            down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
-            drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
+            block,
+            channels,
+            layers,
+            inplanes,
+            cardinality=cardinality,
+            base_width=base_width,
+            output_stride=output_stride,
+            reduce_first=block_reduce_first,
+            avg_down=avg_down,
+            down_kernel_size=down_kernel_size,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            aa_layer=aa_layer,
+            drop_block_rate=drop_block_rate,
+            drop_path_rate=drop_path_rate,
+            **block_args,
+        )
         for stage in stage_modules:
             self.add_module(*stage)  # layer1, layer2, etc
         self.feature_info.extend(stage_feature_info)
@@ -687,9 +766,6 @@ class ResNet(nn.Module):
         for n, m in self.named_modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.ones_(m.weight)
-                nn.init.zeros_(m.bias)
         if zero_init_last:
             for m in self.modules():
                 if hasattr(m, 'zero_init_last'):
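The BatchNorm2d branch removed above was redundant for affine norms and broke affine-free ones: PyTorch already initializes an affine BatchNorm2d with weight=1 and bias=0, and with affine=False those attributes are None, so there is nothing to re-initialize. A quick illustrative check:

import torch.nn as nn

bn = nn.BatchNorm2d(32)
assert (bn.weight == 1).all() and (bn.bias == 0).all()   # default init is already ones/zeros
assert nn.BatchNorm2d(32, affine=False).weight is None   # no affine params to touch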
@@ -155,8 +155,20 @@ class PreActBottleneck(nn.Module):
     """
 
     def __init__(
-            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
-            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs=None,
+            bottle_ratio=0.25,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            groups=1,
+            act_layer=None,
+            conv_layer=None,
+            norm_layer=None,
+            proj_layer=None,
+            drop_path_rate=0.,
+    ):
         super().__init__()
         first_dilation = first_dilation or dilation
         conv_layer = conv_layer or StdConv2d
@@ -202,8 +214,20 @@ class Bottleneck(nn.Module):
     """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.
     """
     def __init__(
-            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
-            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs=None,
+            bottle_ratio=0.25,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            groups=1,
+            act_layer=None,
+            conv_layer=None,
+            norm_layer=None,
+            proj_layer=None,
+            drop_path_rate=0.,
+    ):
         super().__init__()
         first_dilation = first_dilation or dilation
         act_layer = act_layer or nn.ReLU
@@ -229,7 +253,8 @@ class Bottleneck(nn.Module):
         self.act3 = act_layer(inplace=True)
 
     def zero_init_last(self):
-        nn.init.zeros_(self.norm3.weight)
+        if getattr(self.norm3, 'weight', None) is not None:
+            nn.init.zeros_(self.norm3.weight)
 
     def forward(self, x):
         # shortcut branch
@@ -283,9 +308,22 @@ class DownsampleAvg(nn.Module):
 class ResNetStage(nn.Module):
     """ResNet Stage."""
     def __init__(
-            self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
-            avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
-            act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
+            self,
+            in_chs,
+            out_chs,
+            stride,
+            dilation,
+            depth,
+            bottle_ratio=0.25,
+            groups=1,
+            avg_down=False,
+            block_dpr=None,
+            block_fn=PreActBottleneck,
+            act_layer=None,
+            conv_layer=None,
+            norm_layer=None,
+            **block_kwargs,
+    ):
         super(ResNetStage, self).__init__()
         first_dilation = 1 if dilation in (1, 2) else 2
         layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
@@ -313,8 +351,13 @@ def is_stem_deep(stem_type):
 
 
 def create_resnetv2_stem(
-        in_chs, out_chs=64, stem_type='', preact=True,
-        conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
+        in_chs,
+        out_chs=64,
+        stem_type='',
+        preact=True,
+        conv_layer=StdConv2d,
+        norm_layer=partial(GroupNormAct, num_groups=32),
+):
     stem = OrderedDict()
     assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')
 
@@ -357,11 +400,25 @@ class ResNetV2(nn.Module):
     """
 
     def __init__(
-            self, layers, channels=(256, 512, 1024, 2048),
-            num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
-            width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
-            act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
-            drop_rate=0., drop_path_rate=0., zero_init_last=False):
+            self,
+            layers,
+            channels=(256, 512, 1024, 2048),
+            num_classes=1000,
+            in_chans=3,
+            global_pool='avg',
+            output_stride=32,
+            width_factor=1,
+            stem_chs=64,
+            stem_type='',
+            avg_down=False,
+            preact=True,
+            act_layer=nn.ReLU,
+            conv_layer=StdConv2d,
+            norm_layer=partial(GroupNormAct, num_groups=32),
+            drop_rate=0.,
+            drop_path_rate=0.,
+            zero_init_last=False,
+    ):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -387,8 +444,18 @@ class ResNetV2(nn.Module):
                 dilation *= stride
                 stride = 1
             stage = ResNetStage(
-                prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
-                act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
+                prev_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                depth=d,
+                avg_down=avg_down,
+                act_layer=act_layer,
+                conv_layer=conv_layer,
+                norm_layer=norm_layer,
+                block_dpr=bdpr,
+                block_fn=block_fn,
+            )
             prev_chs = out_chs
             curr_stride *= stride
             self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
@@ -47,9 +47,24 @@ class SelectiveKernelBasic(nn.Module):
     expansion = 1
 
     def __init__(
-            self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
-            sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            cardinality=1,
+            base_width=64,
+            sk_kwargs=None,
+            reduce_first=1,
+            dilation=1,
+            first_dilation=None,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            attn_layer=None,
+            aa_layer=None,
+            drop_block=None,
+            drop_path=None,
+    ):
         super(SelectiveKernelBasic, self).__init__()
 
         sk_kwargs = sk_kwargs or {}
@@ -71,7 +86,8 @@ class SelectiveKernelBasic(nn.Module):
         self.drop_path = drop_path
 
     def zero_init_last(self):
-        nn.init.zeros_(self.conv2.bn.weight)
+        if getattr(self.conv2.bn, 'weight', None) is not None:
+            nn.init.zeros_(self.conv2.bn.weight)
 
     def forward(self, x):
         shortcut = x
@@ -92,9 +108,24 @@ class SelectiveKernelBottleneck(nn.Module):
     expansion = 4
 
    def __init__(
-            self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, sk_kwargs=None,
-            reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-            attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            cardinality=1,
+            base_width=64,
+            sk_kwargs=None,
+            reduce_first=1,
+            dilation=1,
+            first_dilation=None,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            attn_layer=None,
+            aa_layer=None,
+            drop_block=None,
+            drop_path=None,
+    ):
         super(SelectiveKernelBottleneck, self).__init__()
 
         sk_kwargs = sk_kwargs or {}
@@ -115,7 +146,8 @@ class SelectiveKernelBottleneck(nn.Module):
         self.drop_path = drop_path
 
     def zero_init_last(self):
-        nn.init.zeros_(self.conv3.bn.weight)
+        if getattr(self.conv3.bn, 'weight', None) is not None:
+            nn.init.zeros_(self.conv3.bn.weight)
 
     def forward(self, x):
         shortcut = x
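A hedged usage sketch of what this change enables; the model name and the assumption that a norm_layer kwarg is forwarded through timm.create_model into the ResNet constructor are illustrative, not taken from the diff. With the guards above, zero-init of the last block norm simply becomes a no-op when the norm has no affine parameters.

from functools import partial

import torch
import torch.nn as nn
import timm

# norm layer without learnable scale/shift (assumed to be passed through create_model kwargs)
model = timm.create_model(
    'resnet50',
    pretrained=False,
    norm_layer=partial(nn.BatchNorm2d, affine=False),
)
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])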