fix cls head in hgnet

This commit is contained in:
方曦 2023-12-27 21:26:26 +08:00
parent 56ae8b906d
commit 9dbea3bef6

View File

@ -423,15 +423,21 @@ class PPHGNet(nn.Module):
self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.head = ClassifierHead(
num_features=self.num_features,
num_classes=num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
use_last_conv=use_last_conv,
class_expand=class_expand,
use_lab=use_lab
)
if num_classes > 0:
self.head = ClassifierHead(
num_features=self.num_features,
num_classes=num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
use_last_conv=use_last_conv,
class_expand=class_expand,
use_lab=use_lab
)
else:
if global_pool == 'avg':
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
else:
self.head = nn.Identity()
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
@ -608,65 +614,65 @@ def _cfg(url='', **kwargs):
}
default_cfgs = generate_default_cfgs({
'hgnet_tiny.paddle_in1k': _cfg(
first_conv='stem.0.conv',
hf_hub_id='timm/'),
'hgnet_tiny.ssld_in1k': _cfg(
first_conv='stem.0.conv',
hf_hub_id='timm/'),
'hgnet_small.paddle_in1k': _cfg(
first_conv='stem.0.conv',
hf_hub_id='timm/'),
'hgnet_small.ssld_in1k': _cfg(
first_conv='stem.0.conv',
hf_hub_id='timm/'),
'hgnet_base.ssld_in1k': _cfg(
first_conv='stem.0.conv',
hf_hub_id='timm/'),
'hgnetv2_b0.ssld_in1k': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b0.ssld_stage1': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b1.ssld_in1k': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b1.ssld_stage1': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b2.ssld_in1k': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b2.ssld_stage1': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b3.ssld_in1k': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b3.ssld_stage1': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b4.ssld_in1k': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b4.ssld_stage1': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b5.ssld_in1k': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b5.ssld_stage1': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b6.ssld_in1k': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
'hgnetv2_b6.ssld_stage1': _cfg(
first_conv='stem.stem1.conv',
hf_hub_id='timm/'),
})
# default_cfgs = generate_default_cfgs({
# 'hgnet_tiny.paddle_in1k': _cfg(
# first_conv='stem.0.conv',
# hf_hub_id='timm/'),
# 'hgnet_tiny.ssld_in1k': _cfg(
# first_conv='stem.0.conv',
# hf_hub_id='timm/'),
# 'hgnet_small.paddle_in1k': _cfg(
# first_conv='stem.0.conv',
# hf_hub_id='timm/'),
# 'hgnet_small.ssld_in1k': _cfg(
# first_conv='stem.0.conv',
# hf_hub_id='timm/'),
# 'hgnet_base.ssld_in1k': _cfg(
# first_conv='stem.0.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b0.ssld_in1k': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b0.ssld_stage1': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b1.ssld_in1k': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b1.ssld_stage1': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b2.ssld_in1k': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b2.ssld_stage1': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b3.ssld_in1k': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b3.ssld_stage1': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b4.ssld_in1k': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b4.ssld_stage1': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b5.ssld_in1k': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b5.ssld_stage1': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b6.ssld_in1k': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# 'hgnetv2_b6.ssld_stage1': _cfg(
# first_conv='stem.stem1.conv',
# hf_hub_id='timm/'),
# })
@register_model