From 9dbea3bef69d05d72906f7bf22cca24687b5884d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E6=9B=A6?= Date: Wed, 27 Dec 2023 21:26:26 +0800 Subject: [PATCH] fix cls head in hgnet --- timm/models/hgnet.py | 142 ++++++++++++++++++++++--------------------- 1 file changed, 74 insertions(+), 68 deletions(-) diff --git a/timm/models/hgnet.py b/timm/models/hgnet.py index 3d25e8c8..450c4bfc 100644 --- a/timm/models/hgnet.py +++ b/timm/models/hgnet.py @@ -423,15 +423,21 @@ class PPHGNet(nn.Module): self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) - self.head = ClassifierHead( - num_features=self.num_features, - num_classes=num_classes, - pool_type=global_pool, - drop_rate=drop_rate, - use_last_conv=use_last_conv, - class_expand=class_expand, - use_lab=use_lab - ) + if num_classes > 0: + self.head = ClassifierHead( + num_features=self.num_features, + num_classes=num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + use_last_conv=use_last_conv, + class_expand=class_expand, + use_lab=use_lab + ) + else: + if global_pool == 'avg': + self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + else: + self.head = nn.Identity() for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): @@ -608,65 +614,65 @@ def _cfg(url='', **kwargs): } -default_cfgs = generate_default_cfgs({ - 'hgnet_tiny.paddle_in1k': _cfg( - first_conv='stem.0.conv', - hf_hub_id='timm/'), - 'hgnet_tiny.ssld_in1k': _cfg( - first_conv='stem.0.conv', - hf_hub_id='timm/'), - 'hgnet_small.paddle_in1k': _cfg( - first_conv='stem.0.conv', - hf_hub_id='timm/'), - 'hgnet_small.ssld_in1k': _cfg( - first_conv='stem.0.conv', - hf_hub_id='timm/'), - 'hgnet_base.ssld_in1k': _cfg( - first_conv='stem.0.conv', - hf_hub_id='timm/'), - 'hgnetv2_b0.ssld_in1k': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b0.ssld_stage1': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b1.ssld_in1k': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b1.ssld_stage1': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b2.ssld_in1k': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b2.ssld_stage1': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b3.ssld_in1k': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b3.ssld_stage1': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b4.ssld_in1k': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b4.ssld_stage1': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b5.ssld_in1k': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b5.ssld_stage1': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b6.ssld_in1k': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), - 'hgnetv2_b6.ssld_stage1': _cfg( - first_conv='stem.stem1.conv', - hf_hub_id='timm/'), -}) +# default_cfgs = generate_default_cfgs({ +# 'hgnet_tiny.paddle_in1k': _cfg( +# first_conv='stem.0.conv', +# hf_hub_id='timm/'), +# 'hgnet_tiny.ssld_in1k': _cfg( +# first_conv='stem.0.conv', +# hf_hub_id='timm/'), +# 'hgnet_small.paddle_in1k': _cfg( +# first_conv='stem.0.conv', +# hf_hub_id='timm/'), +# 'hgnet_small.ssld_in1k': _cfg( +# first_conv='stem.0.conv', +# hf_hub_id='timm/'), +# 'hgnet_base.ssld_in1k': _cfg( +# first_conv='stem.0.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b0.ssld_in1k': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b0.ssld_stage1': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b1.ssld_in1k': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b1.ssld_stage1': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b2.ssld_in1k': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b2.ssld_stage1': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b3.ssld_in1k': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b3.ssld_stage1': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b4.ssld_in1k': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b4.ssld_stage1': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b5.ssld_in1k': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b5.ssld_stage1': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b6.ssld_in1k': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# 'hgnetv2_b6.ssld_stage1': _cfg( +# first_conv='stem.stem1.conv', +# hf_hub_id='timm/'), +# }) @register_model