mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix features for resnetv2_50t
This commit is contained in:
parent
e8045e712f
commit
766b4d3262
@ -291,6 +291,10 @@ class ResNetStage(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def is_stem_deep(stem_type):
|
||||
return any([s in stem_type for s in ('deep', 'tiered')])
|
||||
|
||||
|
||||
def create_resnetv2_stem(
|
||||
in_chs, out_chs=64, stem_type='', preact=True,
|
||||
conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
|
||||
@ -298,7 +302,7 @@ def create_resnetv2_stem(
|
||||
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')
|
||||
|
||||
# NOTE conv padding mode can be changed by overriding the conv_layer def
|
||||
if any([s in stem_type for s in ('deep', 'tiered')]):
|
||||
if is_stem_deep(stem_type):
|
||||
# A 3 deep 3x3 conv stack as in ResNet V1D models
|
||||
if 'tiered' in stem_type:
|
||||
stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py
|
||||
@ -350,7 +354,7 @@ class ResNetV2(nn.Module):
|
||||
stem_chs = make_div(stem_chs * wf)
|
||||
self.stem = create_resnetv2_stem(
|
||||
in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
|
||||
stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv') if preact else 'stem.norm'
|
||||
stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm'
|
||||
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
|
||||
|
||||
prev_chs = stem_chs
|
||||
|
Loading…
x
Reference in New Issue
Block a user