mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix cls head in hgnet
This commit is contained in:
parent
56ae8b906d
commit
9dbea3bef6
@ -423,6 +423,7 @@ class PPHGNet(nn.Module):
|
||||
self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
|
||||
self.stages = nn.Sequential(*stages)
|
||||
|
||||
if num_classes > 0:
|
||||
self.head = ClassifierHead(
|
||||
num_features=self.num_features,
|
||||
num_classes=num_classes,
|
||||
@ -432,6 +433,11 @@ class PPHGNet(nn.Module):
|
||||
class_expand=class_expand,
|
||||
use_lab=use_lab
|
||||
)
|
||||
else:
|
||||
if global_pool == 'avg':
|
||||
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
||||
else:
|
||||
self.head = nn.Identity()
|
||||
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
@ -608,65 +614,65 @@ def _cfg(url='', **kwargs):
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'hgnet_tiny.paddle_in1k': _cfg(
|
||||
first_conv='stem.0.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnet_tiny.ssld_in1k': _cfg(
|
||||
first_conv='stem.0.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnet_small.paddle_in1k': _cfg(
|
||||
first_conv='stem.0.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnet_small.ssld_in1k': _cfg(
|
||||
first_conv='stem.0.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnet_base.ssld_in1k': _cfg(
|
||||
first_conv='stem.0.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b0.ssld_in1k': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b0.ssld_stage1': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b1.ssld_in1k': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b1.ssld_stage1': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b2.ssld_in1k': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b2.ssld_stage1': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b3.ssld_in1k': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b3.ssld_stage1': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b4.ssld_in1k': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b4.ssld_stage1': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b5.ssld_in1k': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b5.ssld_stage1': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b6.ssld_in1k': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
'hgnetv2_b6.ssld_stage1': _cfg(
|
||||
first_conv='stem.stem1.conv',
|
||||
hf_hub_id='timm/'),
|
||||
})
|
||||
# default_cfgs = generate_default_cfgs({
|
||||
# 'hgnet_tiny.paddle_in1k': _cfg(
|
||||
# first_conv='stem.0.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnet_tiny.ssld_in1k': _cfg(
|
||||
# first_conv='stem.0.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnet_small.paddle_in1k': _cfg(
|
||||
# first_conv='stem.0.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnet_small.ssld_in1k': _cfg(
|
||||
# first_conv='stem.0.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnet_base.ssld_in1k': _cfg(
|
||||
# first_conv='stem.0.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b0.ssld_in1k': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b0.ssld_stage1': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b1.ssld_in1k': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b1.ssld_stage1': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b2.ssld_in1k': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b2.ssld_stage1': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b3.ssld_in1k': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b3.ssld_stage1': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b4.ssld_in1k': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b4.ssld_stage1': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b5.ssld_in1k': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b5.ssld_stage1': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b6.ssld_in1k': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# 'hgnetv2_b6.ssld_stage1': _cfg(
|
||||
# first_conv='stem.stem1.conv',
|
||||
# hf_hub_id='timm/'),
|
||||
# })
|
||||
|
||||
|
||||
@register_model
|
||||
|
Loading…
x
Reference in New Issue
Block a user