From 528faa0e04f94dac8260ac8c564e068c943b8359 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 26 Apr 2023 17:46:20 -0700 Subject: [PATCH] Some fixes --- timm/models/efficientnet.py | 16 ++++++++-------- timm/models/hrnet.py | 2 ++ timm/models/res2net.py | 4 ++-- timm/models/sequencer.py | 2 +- timm/models/xception.py | 2 +- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 40bde62b..3a9fc13a 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -961,10 +961,10 @@ default_cfgs = generate_default_cfgs({ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth', hf_hub_id='timm/', input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), crop_pct=1.0), - 'efficientnet_b5.in12k_ft_in1k': _cfg( + 'efficientnet_b5.sw_in12k_ft_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, crop_mode='squash'), - 'efficientnet_b5.in12k': _cfg( + 'efficientnet_b5.sw_in12k': _cfg( hf_hub_id='timm/', input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.95, num_classes=11821), 'efficientnet_b6.untrained': _cfg( @@ -1197,27 +1197,27 @@ default_cfgs = generate_default_cfgs({ 'tf_efficientnet_b0.in1k': _cfg( url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth', - hf_hub_id='timm/', + #hf_hub_id='timm/', input_size=(3, 224, 224)), 'tf_efficientnet_b1.in1k': _cfg( url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth', - hf_hub_id='timm/', + #hf_hub_id='timm/', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), 'tf_efficientnet_b2.in1k': _cfg( url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth', - hf_hub_id='timm/', + #hf_hub_id='timm/', input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), 'tf_efficientnet_b3.in1k': _cfg( url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth', - hf_hub_id='timm/', + #hf_hub_id='timm/', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), 'tf_efficientnet_b4.in1k': _cfg( url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth', - hf_hub_id='timm/', + #hf_hub_id='timm/', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), 'tf_efficientnet_b5.in1k': _cfg( url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth', - hf_hub_id='timm/', + #hf_hub_id='timm/', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index d2d32372..db75bc0f 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -815,6 +815,7 @@ class HighResolutionNetFeatures(HighResolutionNet): drop_rate=0.0, feature_location='incre', out_indices=(0, 1, 2, 3, 4), + **kwargs, ): assert feature_location in ('incre', '') super(HighResolutionNetFeatures, self).__init__( @@ -825,6 +826,7 @@ class HighResolutionNetFeatures(HighResolutionNet): global_pool=global_pool, drop_rate=drop_rate, head=feature_location, + **kwargs, ) self.feature_info = FeatureInfo(self.feature_info, out_indices) self._out_idx = {i for i in out_indices} diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 3d5f5049..5804a4e8 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -139,8 +139,8 @@ default_cfgs = generate_default_cfgs({ 'res2net50_26w_8s.in1k': _cfg(hf_hub_id='timm/'), 'res2net101_26w_4s.in1k': _cfg(hf_hub_id='timm/'), 'res2next50.in1k': _cfg(hf_hub_id='timm/'), - 'res2net50d.in1k': _cfg(hf_hub_id='timm/'), - 'res2net101d.in1k': _cfg(hf_hub_id='timm/'), + 'res2net50d.in1k': _cfg(hf_hub_id='timm/', first_conv='conv1.0'), + 'res2net101d.in1k': _cfg(hf_hub_id='timm/', first_conv='conv1.0'), }) diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py index fb2385ae..2899d29e 100644 --- a/timm/models/sequencer.py +++ b/timm/models/sequencer.py @@ -477,7 +477,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.proj', 'classifier': 'head', + 'first_conv': 'stem.proj', 'classifier': 'head.fc', **kwargs } diff --git a/timm/models/xception.py b/timm/models/xception.py index 70db1aa2..14b6e4f1 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -248,7 +248,7 @@ default_cfgs = generate_default_cfgs({ @register_model def legacy_xception(pretrained=False, **kwargs): - return _xception('xception', pretrained=pretrained, **kwargs) + return _xception('legacy_xception', pretrained=pretrained, **kwargs) register_model_deprecations(__name__, {