mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix cls head in hgnet
This commit is contained in:
parent
56ae8b906d
commit
9dbea3bef6
@ -423,6 +423,7 @@ class PPHGNet(nn.Module):
|
|||||||
self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
|
self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
|
||||||
self.stages = nn.Sequential(*stages)
|
self.stages = nn.Sequential(*stages)
|
||||||
|
|
||||||
|
if num_classes > 0:
|
||||||
self.head = ClassifierHead(
|
self.head = ClassifierHead(
|
||||||
num_features=self.num_features,
|
num_features=self.num_features,
|
||||||
num_classes=num_classes,
|
num_classes=num_classes,
|
||||||
@ -432,6 +433,11 @@ class PPHGNet(nn.Module):
|
|||||||
class_expand=class_expand,
|
class_expand=class_expand,
|
||||||
use_lab=use_lab
|
use_lab=use_lab
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if global_pool == 'avg':
|
||||||
|
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
||||||
|
else:
|
||||||
|
self.head = nn.Identity()
|
||||||
|
|
||||||
for n, m in self.named_modules():
|
for n, m in self.named_modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
@ -608,65 +614,65 @@ def _cfg(url='', **kwargs):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
default_cfgs = generate_default_cfgs({
|
# default_cfgs = generate_default_cfgs({
|
||||||
'hgnet_tiny.paddle_in1k': _cfg(
|
# 'hgnet_tiny.paddle_in1k': _cfg(
|
||||||
first_conv='stem.0.conv',
|
# first_conv='stem.0.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnet_tiny.ssld_in1k': _cfg(
|
# 'hgnet_tiny.ssld_in1k': _cfg(
|
||||||
first_conv='stem.0.conv',
|
# first_conv='stem.0.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnet_small.paddle_in1k': _cfg(
|
# 'hgnet_small.paddle_in1k': _cfg(
|
||||||
first_conv='stem.0.conv',
|
# first_conv='stem.0.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnet_small.ssld_in1k': _cfg(
|
# 'hgnet_small.ssld_in1k': _cfg(
|
||||||
first_conv='stem.0.conv',
|
# first_conv='stem.0.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnet_base.ssld_in1k': _cfg(
|
# 'hgnet_base.ssld_in1k': _cfg(
|
||||||
first_conv='stem.0.conv',
|
# first_conv='stem.0.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b0.ssld_in1k': _cfg(
|
# 'hgnetv2_b0.ssld_in1k': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b0.ssld_stage1': _cfg(
|
# 'hgnetv2_b0.ssld_stage1': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b1.ssld_in1k': _cfg(
|
# 'hgnetv2_b1.ssld_in1k': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b1.ssld_stage1': _cfg(
|
# 'hgnetv2_b1.ssld_stage1': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b2.ssld_in1k': _cfg(
|
# 'hgnetv2_b2.ssld_in1k': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b2.ssld_stage1': _cfg(
|
# 'hgnetv2_b2.ssld_stage1': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b3.ssld_in1k': _cfg(
|
# 'hgnetv2_b3.ssld_in1k': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b3.ssld_stage1': _cfg(
|
# 'hgnetv2_b3.ssld_stage1': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b4.ssld_in1k': _cfg(
|
# 'hgnetv2_b4.ssld_in1k': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b4.ssld_stage1': _cfg(
|
# 'hgnetv2_b4.ssld_stage1': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b5.ssld_in1k': _cfg(
|
# 'hgnetv2_b5.ssld_in1k': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b5.ssld_stage1': _cfg(
|
# 'hgnetv2_b5.ssld_stage1': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b6.ssld_in1k': _cfg(
|
# 'hgnetv2_b6.ssld_in1k': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
'hgnetv2_b6.ssld_stage1': _cfg(
|
# 'hgnetv2_b6.ssld_stage1': _cfg(
|
||||||
first_conv='stem.stem1.conv',
|
# first_conv='stem.stem1.conv',
|
||||||
hf_hub_id='timm/'),
|
# hf_hub_id='timm/'),
|
||||||
})
|
# })
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user