mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Merge pull request #1785 from huggingface/mw-more
More models w/ multi-weight support, adding to HF hub

This commit is contained in: commit 46df4fe633

.github/workflows/tests.yml (vendored)
@@ -48,7 +48,6 @@ jobs:
      - name: Install requirements
        run: |
          pip install -r requirements.txt
-         pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
      - name: Run tests on Windows
        if: startsWith(matrix.os, 'windows')
        env:
@@ -80,7 +80,6 @@ Then install the remaining dependencies:

```
python -m pip install -r requirements.txt
python -m pip install -r requirements-dev.txt  # for testing
-python -m pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
python -m pip install -e .
```
@@ -44,7 +44,7 @@ from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_e
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct
-from .space_to_depth import SpaceToDepthModule
+from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
@@ -78,6 +78,7 @@ _ACT_LAYER_DEFAULT = dict(
    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
    hard_swish=nn.Hardswish if _has_hardswish else HardSwish,
    hard_mish=HardMish,
+   identity=nn.Identity,
)

_ACT_LAYER_JIT = dict(
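Registering `identity` in the default table means "no activation" is now selectable by name like any other activation. A minimal sketch of the effect, assuming `get_act_layer` resolves string names through `_ACT_LAYER_DEFAULT`:

```python
import torch.nn as nn
from timm.layers import get_act_layer

# 'identity' now resolves to nn.Identity instead of having no entry
act_cls = get_act_layer('identity')
assert act_cls is nn.Identity
```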
@@ -40,6 +40,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
            track_running_stats=True,
            apply_act=True,
            act_layer=nn.ReLU,
+            act_params=None,  # FIXME not the final approach
            inplace=True,
            drop_layer=None,
            device=None,
@@ -59,6 +60,8 @@ class BatchNormAct2d(nn.BatchNorm2d):
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
+            if act_params is not None:
+                act_args['negative_slope'] = act_params
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()
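The new `act_params` value is forwarded to the activation as `negative_slope`, so a leaky-ReLU slope can be set directly on the fused norm+act layer. A usage sketch based on the constructor above (the plumbing is explicitly marked FIXME, so treat it as provisional):

```python
import torch
import torch.nn as nn
from timm.layers import BatchNormAct2d

# act_params=1e-3 reaches nn.LeakyReLU(negative_slope=1e-3, inplace=True)
bn_act = BatchNormAct2d(64, act_layer=nn.LeakyReLU, act_params=1e-3)
y = bn_act(torch.randn(2, 64, 8, 8))
```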
@@ -17,7 +17,7 @@ class SpaceToDepth(nn.Module):


@torch.jit.script
-class SpaceToDepthJit(object):
+class SpaceToDepthJit:
    def __call__(self, x: torch.Tensor):
        # assuming hard-coded that block_size==4 for acceleration
        N, C, H, W = x.size()
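For reference, the rearrangement these classes implement folds each spatial block into channels. A minimal standalone sketch of the general operation (block size is a parameter here; the JIT class above hard-codes 4):

```python
import torch

def space_to_depth(x: torch.Tensor, bs: int = 4) -> torch.Tensor:
    # (N, C, H, W) -> (N, C * bs * bs, H // bs, W // bs)
    N, C, H, W = x.size()
    x = x.view(N, C, H // bs, bs, W // bs, bs)
    x = x.permute(0, 3, 5, 1, 2, 4)  # move the block offsets in front of C
    return x.reshape(N, C * bs * bs, H // bs, W // bs)
```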
@@ -314,7 +314,7 @@ def push_to_hf_hub(
def generate_readme(model_card: dict, model_name: str):
    readme_text = "---\n"
    readme_text += "tags:\n- image-classification\n- timm\n"
-   readme_text += "library_tag: timm\n"
+   readme_text += "library_name: timm\n"
    readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
    if 'details' in model_card and 'Dataset' in model_card['details']:
        readme_text += 'datasets:\n'
@@ -14,76 +14,12 @@ Hacked together by / copyright Ross Wightman, 2021.
"""
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ._builder import build_model_with_cfg
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs
from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks

__all__ = []


-def _cfg(url='', **kwargs):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.95, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
-        'fixed_input_size': False, 'min_input_size': (3, 224, 224),
-        **kwargs
-    }
-
-
-default_cfgs = {
-    # GPU-Efficient (ResNet) weights
-    'botnet26t_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
-    'sebotnet33ts_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
-    'botnet50ts_256': _cfg(
-        url='',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
-    'eca_botnext26ts_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_botnext26ts_c_256-95a898f6.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
-
-    'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
-    'halonet26t': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_a1h_256-3083328c.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
-    'sehalonet33ts': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
-    'halonet50ts': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
-    'eca_halonext26ts': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
-
-    'lambda_resnet26t': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_c_256-e5a5c857.pth',
-        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
-    'lambda_resnet50ts': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet50ts_a1h_256-b87370f7.pth',
-        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
-    'lambda_resnet26rpt_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_c_256-ab00292d.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
-
-    'haloregnetz_b': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth',
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-        first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),
-
-    'lamhalobotnet50ts_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lamhalobotnet50ts_a1h2_256-fe3d9445.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
-    'halo2botnet50ts_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halo2botnet50ts_a1h2_256-fd9c11a3.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
-}
-
-
model_cfgs = dict(

    botnet26t=ByoModelCfg(
@@ -329,7 +265,71 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
        ByobNet, variant, pretrained,
        model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
        feature_cfg=dict(flatten_sequential=True),
-        **kwargs)
+        **kwargs,
+    )


+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.95, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
+        'fixed_input_size': False, 'min_input_size': (3, 224, 224),
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    # GPU-Efficient (ResNet) weights
+    'botnet26t_256.c1_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+    'sebotnet33ts_256.a1h_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
+    'botnet50ts_256.untrained': _cfg(
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+    'eca_botnext26ts_256.c1_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_botnext26ts_c_256-95a898f6.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+
+    'halonet_h1.untrained': _cfg(input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+    'halonet26t.a1h_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_a1h_256-3083328c.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+    'sehalonet33ts.ra2_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
+    'halonet50ts.a1h_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
+    'eca_halonext26ts.c1_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
+
+    'lambda_resnet26t.c1_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_c_256-e5a5c857.pth',
+        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
+    'lambda_resnet50ts.a1h_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet50ts_a1h_256-b87370f7.pth',
+        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
+    'lambda_resnet26rpt_256.c1_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_c_256-ab00292d.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
+
+    'haloregnetz_b.ra3_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),
+
+    'lamhalobotnet50ts_256.a1h_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lamhalobotnet50ts_a1h2_256-fe3d9445.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+    'halo2botnet50ts_256.a1h_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halo2botnet50ts_a1h2_256-fd9c11a3.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+})


@register_model
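The keys now follow the multi-weight `architecture.tag` scheme expanded by `generate_default_cfgs`, so one architecture can carry several weight sets (`.untrained` marks entries without weights). Selecting a specific set at creation time looks like:

```python
import timm

# the tag picks the weight set; omitting it falls back to the default entry
model = timm.create_model('botnet26t_256.c1_in1k', pretrained=True)
```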
@@ -15,39 +15,21 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
from ._builder import build_model_with_cfg
from ._manipulate import MATCH_PREV_GROUP
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs

__all__ = ['DenseNet']


-def _cfg(url=''):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'features.conv0', 'classifier': 'classifier',
-    }
-
-
-default_cfgs = {
-    'densenet121': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenet121_ra-50efcf5c.pth'),
-    'densenet121d': _cfg(url=''),
-    'densenetblur121d': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenetblur121d_ra-100dcfbc.pth'),
-    'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'),
-    'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'),
-    'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'),
-    'densenet264': _cfg(url=''),
-    'densenet264d_iabn': _cfg(url=''),
-    'tv_densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-a639ec97.pth'),
-}
-
-
class DenseLayer(nn.Module):
    def __init__(
-            self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
-            drop_rate=0., memory_efficient=False):
+            self,
+            num_input_features,
+            growth_rate,
+            bn_size,
+            norm_layer=BatchNormAct2d,
+            drop_rate=0.,
+            memory_efficient=False,
+    ):
        super(DenseLayer, self).__init__()
        self.add_module('norm1', norm_layer(num_input_features)),
        self.add_module('conv1', nn.Conv2d(
@@ -145,7 +127,13 @@ class DenseBlock(nn.ModuleDict):


class DenseTransition(nn.Sequential):
-    def __init__(self, num_input_features, num_output_features, norm_layer=BatchNormAct2d, aa_layer=None):
+    def __init__(
+            self,
+            num_input_features,
+            num_output_features,
+            norm_layer=BatchNormAct2d,
+            aa_layer=None,
+    ):
        super(DenseTransition, self).__init__()
        self.add_module('norm', norm_layer(num_input_features))
        self.add_module('conv', nn.Conv2d(
@@ -324,9 +312,35 @@ def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs):
    kwargs['growth_rate'] = growth_rate
    kwargs['block_config'] = block_config
    return build_model_with_cfg(
-        DenseNet, variant, pretrained,
+        DenseNet,
+        variant,
+        pretrained,
        feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained,
-        **kwargs)
+        **kwargs,
+    )


+def _cfg(url=''):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'features.conv0', 'classifier': 'classifier',
+    }
+
+
+default_cfgs = {
+    'densenet121.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenet121_ra-50efcf5c.pth'),
+    'densenet121d': _cfg(url=''),
+    'densenetblur121d.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenetblur121d_ra-100dcfbc.pth'),
+    'densenet169.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'),
+    'densenet201.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'),
+    'densenet161.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'),
+    'densenet264.untrained': _cfg(url=''),
+    'densenet121.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet121-a639ec97.pth'),
+}


@register_model
@@ -400,22 +414,3 @@ def densenet264(pretrained=False, **kwargs):
        'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs)
    return model
-
-
-@register_model
-def densenet264d_iabn(pretrained=False, **kwargs):
-    r"""Densenet-264 model with deep stem and Inplace-ABN
-    """
-    model = _create_densenet(
-        'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep',
-        norm_layer='iabn', act_layer='leaky_relu', pretrained=pretrained, **kwargs)
-    return model
-
-
-@register_model
-def tv_densenet121(pretrained=False, **kwargs):
-    r"""Densenet-121 model with original Torchvision weights, from
-    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
-    """
-    model = _create_densenet(
-        'tv_densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs)
-    return model
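With the torchvision weights now folded into tags like `densenet121.tv_in1k`, the weights available for an architecture can be enumerated from the registry. A small sketch, assuming the `list_pretrained` helper present in timm of this vintage:

```python
import timm

# lists tagged entries that actually ship weights, e.g. densenet121.ra_in1k
print(timm.list_pretrained('densenet121*'))
```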
@@ -17,39 +17,11 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier, get_norm_act_layer
from ._builder import build_model_with_cfg
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs

__all__ = ['DPN']


-def _cfg(url='', **kwargs):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD,
-        'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'dpn48b': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    'dpn68': _cfg(
-        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
-    'dpn68b': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dpn68b_ra-a31ca160.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    'dpn92': _cfg(
-        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'),
-    'dpn98': _cfg(
-        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'),
-    'dpn131': _cfg(
-        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth'),
-    'dpn107': _cfg(
-        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth')
-}
-
-
class CatBnAct(nn.Module):
    def __init__(self, in_chs, norm_layer=BatchNormAct2d):
        super(CatBnAct, self).__init__()
@@ -310,9 +282,42 @@ class DPN(nn.Module):

def _create_dpn(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
-        DPN, variant, pretrained,
+        DPN,
+        variant,
+        pretrained,
        feature_cfg=dict(feature_concat=True, flatten_sequential=True),
-        **kwargs)
+        **kwargs,
+    )


+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD,
+        'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'dpn48b.untrained': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'dpn68.mx_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
+    'dpn68b.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dpn68b_ra-a31ca160.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'dpn68b.mx_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth'),
+    'dpn92.mx_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'),
+    'dpn98.mx_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'),
+    'dpn131.mx_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth'),
+    'dpn107.mx_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth')
})


@register_model
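Note that `dpn68b` now carries two tags, the timm RandAugment weights and the original MXNet conversion, selectable independently:

```python
import timm

ra = timm.create_model('dpn68b.ra_in1k', pretrained=True)   # timm-trained weights
mx = timm.create_model('dpn68b.mx_in1k', pretrained=True)   # converted MXNet weights
```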
@@ -7,32 +7,12 @@ from ._builder import build_model_with_cfg
from ._builder import pretrained_cfg_for_features
from ._efficientnet_blocks import SqueezeExcite
from ._efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs
from .mobilenetv3 import MobileNetV3, MobileNetV3Features

__all__ = []  # model_registry will add each entrypoint fn to this


-def _cfg(url='', **kwargs):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'conv_stem', 'classifier': 'classifier',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'hardcorenas_a': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_a_green_38ms_75_9-31dc7186.pth'),
-    'hardcorenas_b': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_b_green_40ms_76_5-32d91ff2.pth'),
-    'hardcorenas_c': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_c_green_44ms_77_1-631a0983.pth'),
-    'hardcorenas_d': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_d_green_50ms_77_4-998d9d7a.pth'),
-    'hardcorenas_e': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_e_green_55ms_77_9-482886a3.pth'),
-    'hardcorenas_f': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_f_green_60ms_78_1-14b9e780.pth'),
-}
-
-
def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
    """Creates a hardcorenas model
@@ -60,15 +40,44 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
        kwargs_filter = ('num_classes', 'num_features', 'global_pool', 'head_conv', 'head_bias', 'global_pool')
        model_cls = MobileNetV3Features
    model = build_model_with_cfg(
-        model_cls, variant, pretrained,
+        model_cls,
+        variant,
+        pretrained,
        pretrained_strict=not features_only,
        kwargs_filter=kwargs_filter,
-        **model_kwargs)
+        **model_kwargs,
+    )
    if features_only:
        model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
    return model


+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'conv_stem', 'classifier': 'classifier',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'hardcorenas_a.green_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_a_green_38ms_75_9-31dc7186.pth'),
+    'hardcorenas_b.green_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_b_green_40ms_76_5-32d91ff2.pth'),
+    'hardcorenas_c.green_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_c_green_44ms_77_1-631a0983.pth'),
+    'hardcorenas_d.green_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_d_green_50ms_77_4-998d9d7a.pth'),
+    'hardcorenas_e.green_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_e_green_55ms_77_9-482886a3.pth'),
+    'hardcorenas_f.green_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_f_green_60ms_78_1-14b9e780.pth'),
+})


@register_model
def hardcorenas_a(pretrained=False, **kwargs):
    """ hardcorenas_A """
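The `features_only` path above filters classifier-head kwargs out and loads weights non-strictly, since the head is absent in the feature-extraction variant. A usage sketch:

```python
import torch
import timm

model = timm.create_model('hardcorenas_a', pretrained=True, features_only=True)
feature_maps = model(torch.randn(1, 3, 224, 224))  # list of per-stage tensors
```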
@@ -6,104 +6,62 @@ Original model: https://github.com/mrT23/TResNet

"""
from collections import OrderedDict
+from functools import partial

import torch
import torch.nn as nn

-from timm.layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule
+from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, \
+    ConvNormActAa, ConvNormAct, DropPath
from ._builder import build_model_with_cfg
-from ._registry import register_model
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs, register_model_deprecations

__all__ = ['TResNet']  # model_registry will add each entrypoint fn to this


-def _cfg(url='', **kwargs):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
-        'mean': (0., 0., 0.), 'std': (1., 1., 1.),
-        'first_conv': 'body.conv1.0', 'classifier': 'head.fc',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'tresnet_m': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_1k_miil_83_1-d236afcb.pth'),
-    'tresnet_m_miil_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_miil_in21k-901b6ed4.pth', num_classes=11221),
-    'tresnet_l': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth'),
-    'tresnet_xl': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'),
-    'tresnet_m_448': _cfg(
-        input_size=(3, 448, 448), pool_size=(14, 14),
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'),
-    'tresnet_l_448': _cfg(
-        input_size=(3, 448, 448), pool_size=(14, 14),
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'),
-    'tresnet_xl_448': _cfg(
-        input_size=(3, 448, 448), pool_size=(14, 14),
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth'),
-
-    'tresnet_v2_l': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_v2_83_9-f36e4445.pth'),
-}
-
-
-def IABN2Float(module: nn.Module) -> nn.Module:
-    """If `module` is IABN don't use half precision."""
-    if isinstance(module, InplaceAbn):
-        module.float()
-    for child in module.children():
-        IABN2Float(child)
-    return module
-
-
-def conv2d_iabn(ni, nf, stride, kernel_size=3, groups=1, act_layer="leaky_relu", act_param=1e-2):
-    return nn.Sequential(
-        nn.Conv2d(
-            ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups, bias=False),
-        InplaceAbn(nf, act_layer=act_layer, act_param=act_param)
-    )
-
-
class BasicBlock(nn.Module):
    expansion = 1

-    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, aa_layer=None):
+    def __init__(
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            use_se=True,
+            aa_layer=None,
+            drop_path_rate=0.
+    ):
        super(BasicBlock, self).__init__()
-        if stride == 1:
-            self.conv1 = conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3)
-        else:
-            if aa_layer is None:
-                self.conv1 = conv2d_iabn(inplanes, planes, stride=2, act_param=1e-3)
-            else:
-                self.conv1 = nn.Sequential(
-                    conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3),
-                    aa_layer(channels=planes, filt_size=3, stride=2))
-
-        self.conv2 = conv2d_iabn(planes, planes, stride=1, act_layer="identity")
-        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
+        act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)
+
+        if stride == 1:
+            self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=1, act_layer=act_layer)
+        else:
+            self.conv1 = ConvNormActAa(
+                inplanes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
+
+        self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, act_layer=None)
+        self.act = nn.ReLU(inplace=True)

        rd_chs = max(planes * self.expansion // 4, 64)
        self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x):
        if self.downsample is not None:
            shortcut = self.downsample(x)
        else:
            shortcut = x

        out = self.conv1(x)
        out = self.conv2(out)

        if self.se is not None:
            out = self.se(out)

-        out = out + shortcut
-        out = self.relu(out)
+        out = self.drop_path(out) + shortcut
+        out = self.act(out)
        return out
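The rewritten block adds stochastic depth via `DropPath`. For reference, a minimal sketch of what a DropPath layer does (timm's implementation follows this pattern):

```python
import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Zero a residual branch per sample with prob p, rescaling survivors."""
    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1. - self.drop_prob
        # one Bernoulli draw per sample, broadcast over C, H, W
        mask = x.new_empty((x.shape[0], 1, 1, 1)).bernoulli_(keep_prob)
        return x * mask / keep_prob
```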
@@ -111,47 +69,51 @@ class Bottleneck(nn.Module):
    expansion = 4

    def __init__(
-            self, inplanes, planes, stride=1, downsample=None, use_se=True,
-            act_layer="leaky_relu", aa_layer=None):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            use_se=True,
+            act_layer=None,
+            aa_layer=None,
+            drop_path_rate=0.,
+    ):
        super(Bottleneck, self).__init__()
-        self.conv1 = conv2d_iabn(
-            inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, act_param=1e-3)
-        self.downsample = downsample
-        self.stride = stride
+        act_layer = act_layer or partial(nn.LeakyReLU, negative_slope=1e-3)
+
+        self.conv1 = ConvNormAct(
+            inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer)
        if stride == 1:
-            self.conv2 = conv2d_iabn(
-                planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3)
+            self.conv2 = ConvNormAct(
+                planes, planes, kernel_size=3, stride=1, act_layer=act_layer)
        else:
-            if aa_layer is None:
-                self.conv2 = conv2d_iabn(
-                    planes, planes, kernel_size=3, stride=2, act_layer=act_layer, act_param=1e-3)
-            else:
-                self.conv2 = nn.Sequential(
-                    conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
-                    aa_layer(channels=planes, filt_size=3, stride=2))
+            self.conv2 = ConvNormActAa(
+                planes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)

        reduction_chs = max(planes * self.expansion // 8, 64)
        self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None

-        self.conv3 = conv2d_iabn(
-            planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
+        self.conv3 = ConvNormAct(
+            planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)

+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+        self.act = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride

    def forward(self, x):
        if self.downsample is not None:
            shortcut = self.downsample(x)
        else:
            shortcut = x

        out = self.conv1(x)
        out = self.conv2(out)
        if self.se is not None:
            out = self.se(out)
        out = self.conv3(out)
-        out = out + shortcut  # no inplace
+        out = self.drop_path(out) + shortcut
+        out = self.act(out)

        return out
@@ -165,12 +127,15 @@ class TResNet(nn.Module):
            v2=False,
            global_pool='fast',
            drop_rate=0.,
+            drop_path_rate=0.,
    ):
        self.num_classes = num_classes
        self.drop_rate = drop_rate
+        self.grad_checkpointing = False
        super(TResNet, self).__init__()

        aa_layer = BlurPool2d
+        act_layer = nn.LeakyReLU

        # TResnet stages
        self.inplanes = int(64 * width_factor)
@@ -179,24 +144,30 @@ class TResNet(nn.Module):
        self.inplanes = self.inplanes // 8 * 8
        self.planes = self.planes // 8 * 8

-        conv1 = conv2d_iabn(in_chans * 16, self.planes, stride=1, kernel_size=3)
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
+        conv1 = ConvNormAct(in_chans * 16, self.planes, stride=1, kernel_size=3, act_layer=act_layer)
        layer1 = self._make_layer(
-            Bottleneck if v2 else BasicBlock, self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer)
+            Bottleneck if v2 else BasicBlock,
+            self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[0])
        layer2 = self._make_layer(
-            Bottleneck if v2 else BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer)
+            Bottleneck if v2 else BasicBlock,
+            self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[1])
        layer3 = self._make_layer(
-            Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer)
+            Bottleneck,
+            self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[2])
        layer4 = self._make_layer(
-            Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer)
+            Bottleneck,
+            self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer, drop_path_rate=dpr[3])

        # body
        self.body = nn.Sequential(OrderedDict([
-            ('SpaceToDepth', SpaceToDepthModule()),
+            ('s2d', SpaceToDepth()),
            ('conv1', conv1),
            ('layer1', layer1),
            ('layer2', layer2),
            ('layer3', layer3),
-            ('layer4', layer4)]))
+            ('layer4', layer4),
+        ]))

        self.feature_info = [
            dict(num_chs=self.planes, reduction=2, module=''),  # Not with S2D?
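The `dpr` line builds a linearly increasing drop-path schedule over all blocks and splits it per stage, so every block gets its own rate. A worked example:

```python
import torch

layers, drop_path_rate = [3, 4, 11, 3], 0.1
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
print([len(d) for d in dpr])  # [3, 4, 11, 3]; rates rise from 0.0 toward 0.1
```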
@@ -214,37 +185,39 @@ class TResNet(nn.Module):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InplaceAbn):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
+            if isinstance(m, nn.Linear):
+                m.weight.data.normal_(0, 0.01)

        # residual connections special initialization
        for m in self.modules():
            if isinstance(m, BasicBlock):
-                m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight))  # BN to zero
+                nn.init.zeros_(m.conv2.bn.weight)
            if isinstance(m, Bottleneck):
-                m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight))  # BN to zero
-            if isinstance(m, nn.Linear):
-                m.weight.data.normal_(0, 0.01)
+                nn.init.zeros_(m.conv3.bn.weight)

-    def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None):
+    def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None, drop_path_rate=0.):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            layers = []
            if stride == 2:
                # avg pooling before 1x1 conv
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
-            layers += [conv2d_iabn(
-                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, act_layer="identity")]
+            layers += [ConvNormAct(
+                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)]
            downsample = nn.Sequential(*layers)

        layers = []
-        layers.append(block(
-            self.inplanes, planes, stride, downsample, use_se=use_se, aa_layer=aa_layer))
-        self.inplanes = planes * block.expansion
-        for i in range(1, blocks):
-            layers.append(
-                block(self.inplanes, planes, use_se=use_se, aa_layer=aa_layer))
+        for i in range(blocks):
+            layers.append(block(
+                self.inplanes,
+                planes,
+                stride=stride if i == 0 else 1,
+                downsample=downsample if i == 0 else None,
+                use_se=use_se,
+                aa_layer=aa_layer,
+                drop_path_rate=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate,
+            ))
+            self.inplanes = planes * block.expansion
        return nn.Sequential(*layers)

    @torch.jit.ignore
@@ -254,18 +227,28 @@ class TResNet(nn.Module):

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
-        assert not enable, 'gradient checkpointing not supported'
+        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

-    def reset_classifier(self, num_classes, global_pool='fast'):
-        self.head = ClassifierHead(
-            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.head.reset(num_classes, pool_type=global_pool)

    def forward_features(self, x):
-        return self.body(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = self.body.s2d(x)
+            x = self.body.conv1(x)
+            x = checkpoint_seq([
+                self.body.layer1,
+                self.body.layer2,
+                self.body.layer3,
+                self.body.layer4],
+                x, flatten=True)
+        else:
+            x = self.body(x)
+        return x

    def forward_head(self, x, pre_logits: bool = False):
        return x if pre_logits else self.head(x)
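Gradient checkpointing is now actually wired up: the stem runs normally and the four stages are re-computed during backward. Enabling it is one call:

```python
import torch
import timm

model = timm.create_model('tresnet_m')
model.set_grad_checkpointing(True)  # trades extra compute for activation memory
out = model(torch.randn(2, 3, 224, 224))
```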
@@ -276,11 +259,74 @@ class TResNet(nn.Module):
        return x


+def checkpoint_filter_fn(state_dict, model):
+    if 'body.conv1.conv.weight' in state_dict:
+        return state_dict
+
+    import re
+    state_dict = state_dict.get('model', state_dict)
+    state_dict = state_dict.get('state_dict', state_dict)
+    out_dict = {}
+    for k, v in state_dict.items():
+        k = re.sub(r'conv(\d+)\.0.0', lambda x: f'conv{int(x.group(1))}.conv', k)
+        k = re.sub(r'conv(\d+)\.0.1', lambda x: f'conv{int(x.group(1))}.bn', k)
+        k = re.sub(r'conv(\d+)\.0', lambda x: f'conv{int(x.group(1))}.conv', k)
+        k = re.sub(r'conv(\d+)\.1', lambda x: f'conv{int(x.group(1))}.bn', k)
+        k = re.sub(r'downsample\.(\d+)\.0', lambda x: f'downsample.{int(x.group(1))}.conv', k)
+        k = re.sub(r'downsample\.(\d+)\.1', lambda x: f'downsample.{int(x.group(1))}.bn', k)
+        if k.endswith('bn.weight'):
+            # convert weight from inplace_abn to batchnorm
+            v = v.abs().add(1e-5)
+        out_dict[k] = v
+    return out_dict


def _create_tresnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
-        TResNet, variant, pretrained,
+        TResNet,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True),
-        **kwargs)
+        **kwargs,
+    )


+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': (0., 0., 0.), 'std': (1., 1., 1.),
+        'first_conv': 'body.conv1.conv', 'classifier': 'head.fc',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'tresnet_m.miil_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_1k_miil_83_1-d236afcb.pth'),
+    'tresnet_m.miil_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_miil_in21k-901b6ed4.pth',
+        num_classes=11221),
+    'tresnet_m.miil_in1k': _cfg(
+        url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_80_8-dbc13962.pth'),
+    'tresnet_l.miil_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth'),
+    'tresnet_xl.miil_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'),
+    'tresnet_m.miil_in1k_448': _cfg(
+        input_size=(3, 448, 448), pool_size=(14, 14),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'),
+    'tresnet_l.miil_in1k_448': _cfg(
+        input_size=(3, 448, 448), pool_size=(14, 14),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'),
+    'tresnet_xl.miil_in1k_448': _cfg(
+        input_size=(3, 448, 448), pool_size=(14, 14),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth'),
+
+    'tresnet_v2_l.miil_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_v2_83_9-f36e4445.pth'),
+})


@register_model
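The `v.abs().add(1e-5)` step exists because InplaceABN folds the activation's sign handling into its gamma, which may be negative, while a plain BatchNorm2d expects a positive scale. A tiny illustration of the conversion:

```python
import torch

gamma = torch.tensor([-0.3, 0.0, 0.7])  # as stored by an inplace_abn checkpoint
print(gamma.abs().add(1e-5))            # positive scale usable by BatchNorm2d
```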
@@ -289,24 +335,12 @@ def tresnet_m(pretrained=False, **kwargs):
    return _create_tresnet('tresnet_m', pretrained=pretrained, **model_kwargs)
-
-
-@register_model
-def tresnet_m_miil_in21k(pretrained=False, **kwargs):
-    model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs)
-    return _create_tresnet('tresnet_m_miil_in21k', pretrained=pretrained, **model_kwargs)


@register_model
def tresnet_l(pretrained=False, **kwargs):
    model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs)
    return _create_tresnet('tresnet_l', pretrained=pretrained, **model_kwargs)
-
-
-@register_model
-def tresnet_v2_l(pretrained=False, **kwargs):
-    model_kwargs = dict(layers=[3, 4, 23, 3], width_factor=1.0, v2=True, **kwargs)
-    return _create_tresnet('tresnet_v2_l', pretrained=pretrained, **model_kwargs)


@register_model
def tresnet_xl(pretrained=False, **kwargs):
    model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs)
@@ -314,18 +348,14 @@ def tresnet_xl(pretrained=False, **kwargs):


@register_model
-def tresnet_m_448(pretrained=False, **kwargs):
-    model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs)
-    return _create_tresnet('tresnet_m_448', pretrained=pretrained, **model_kwargs)
+def tresnet_v2_l(pretrained=False, **kwargs):
+    model_kwargs = dict(layers=[3, 4, 23, 3], width_factor=1.0, v2=True, **kwargs)
+    return _create_tresnet('tresnet_v2_l', pretrained=pretrained, **model_kwargs)


-@register_model
-def tresnet_l_448(pretrained=False, **kwargs):
-    model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs)
-    return _create_tresnet('tresnet_l_448', pretrained=pretrained, **model_kwargs)
-
-
-@register_model
-def tresnet_xl_448(pretrained=False, **kwargs):
-    model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs)
-    return _create_tresnet('tresnet_xl_448', pretrained=pretrained, **model_kwargs)
+register_model_deprecations(__name__, {
+    'tresnet_m_miil_in21k': 'tresnet_m.miil_in21k',
+    'tresnet_m_448': 'tresnet_m.miil_in1k_448',
+    'tresnet_l_448': 'tresnet_l.miil_in1k_448',
+    'tresnet_xl_448': 'tresnet_xl.miil_in1k_448',
+})
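The removed `*_448` and `*_miil_in21k` entrypoints survive as deprecation aliases, so old names still resolve to the new `architecture.tag` form (presumably with a deprecation warning):

```python
import timm

# resolves through the deprecation map to 'tresnet_m.miil_in1k_448'
model = timm.create_model('tresnet_m_448', pretrained=True)
```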
@@ -18,152 +18,14 @@ import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \
-    create_attn, create_norm_act_layer, get_norm_act_layer
+    create_attn, create_norm_act_layer
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs

__all__ = ['VovNet']  # model_registry will add each entrypoint fn to this


-# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
-# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
-model_cfgs = dict(
-    vovnet39a=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=5,
-        block_per_stage=[1, 1, 2, 2],
-        residual=False,
-        depthwise=False,
-        attn='',
-    ),
-    vovnet57a=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=5,
-        block_per_stage=[1, 1, 4, 3],
-        residual=False,
-        depthwise=False,
-        attn='',
-
-    ),
-    ese_vovnet19b_slim_dw=dict(
-        stem_chs=[64, 64, 64],
-        stage_conv_chs=[64, 80, 96, 112],
-        stage_out_chs=[112, 256, 384, 512],
-        layer_per_block=3,
-        block_per_stage=[1, 1, 1, 1],
-        residual=True,
-        depthwise=True,
-        attn='ese',
-
-    ),
-    ese_vovnet19b_dw=dict(
-        stem_chs=[64, 64, 64],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=3,
-        block_per_stage=[1, 1, 1, 1],
-        residual=True,
-        depthwise=True,
-        attn='ese',
-    ),
-    ese_vovnet19b_slim=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[64, 80, 96, 112],
-        stage_out_chs=[112, 256, 384, 512],
-        layer_per_block=3,
-        block_per_stage=[1, 1, 1, 1],
-        residual=True,
-        depthwise=False,
-        attn='ese',
-    ),
-    ese_vovnet19b=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=3,
-        block_per_stage=[1, 1, 1, 1],
-        residual=True,
-        depthwise=False,
-        attn='ese',
-
-    ),
-    ese_vovnet39b=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=5,
-        block_per_stage=[1, 1, 2, 2],
-        residual=True,
-        depthwise=False,
-        attn='ese',
-    ),
-    ese_vovnet57b=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=5,
-        block_per_stage=[1, 1, 4, 3],
-        residual=True,
-        depthwise=False,
-        attn='ese',
-
-    ),
-    ese_vovnet99b=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=5,
-        block_per_stage=[1, 3, 9, 3],
-        residual=True,
-        depthwise=False,
-        attn='ese',
-    ),
-    eca_vovnet39b=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=5,
-        block_per_stage=[1, 1, 2, 2],
-        residual=True,
-        depthwise=False,
-        attn='eca',
-    ),
-)
-model_cfgs['ese_vovnet39b_evos'] = model_cfgs['ese_vovnet39b']
-model_cfgs['ese_vovnet99b_iabn'] = model_cfgs['ese_vovnet99b']
-
-
-def _cfg(url=''):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
-    }
-
-
-default_cfgs = dict(
-    vovnet39a=_cfg(url=''),
-    vovnet57a=_cfg(url=''),
-    ese_vovnet19b_slim_dw=_cfg(url=''),
-    ese_vovnet19b_dw=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet19b_dw-a8741004.pth'),
-    ese_vovnet19b_slim=_cfg(url=''),
-    ese_vovnet39b=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet39b-f912fe73.pth'),
-    ese_vovnet57b=_cfg(url=''),
-    ese_vovnet99b=_cfg(url=''),
-    eca_vovnet39b=_cfg(url=''),
-    ese_vovnet39b_evos=_cfg(url=''),
-    ese_vovnet99b_iabn=_cfg(url=''),
-)
-
-
class SequentialAppendList(nn.Sequential):
    def __init__(self, *args):
        super(SequentialAppendList, self).__init__(*args)
@@ -405,12 +267,151 @@ class VovNet(nn.Module):
        return x


+# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
+# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
+model_cfgs = dict(
+    vovnet39a=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=5,
+        block_per_stage=[1, 1, 2, 2],
+        residual=False,
+        depthwise=False,
+        attn='',
+    ),
+    vovnet57a=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=5,
+        block_per_stage=[1, 1, 4, 3],
+        residual=False,
+        depthwise=False,
+        attn='',
+
+    ),
+    ese_vovnet19b_slim_dw=dict(
+        stem_chs=[64, 64, 64],
+        stage_conv_chs=[64, 80, 96, 112],
+        stage_out_chs=[112, 256, 384, 512],
+        layer_per_block=3,
+        block_per_stage=[1, 1, 1, 1],
+        residual=True,
+        depthwise=True,
+        attn='ese',
+
+    ),
+    ese_vovnet19b_dw=dict(
+        stem_chs=[64, 64, 64],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=3,
+        block_per_stage=[1, 1, 1, 1],
+        residual=True,
+        depthwise=True,
+        attn='ese',
+    ),
+    ese_vovnet19b_slim=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[64, 80, 96, 112],
+        stage_out_chs=[112, 256, 384, 512],
+        layer_per_block=3,
+        block_per_stage=[1, 1, 1, 1],
+        residual=True,
+        depthwise=False,
+        attn='ese',
+    ),
+    ese_vovnet19b=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=3,
+        block_per_stage=[1, 1, 1, 1],
+        residual=True,
+        depthwise=False,
+        attn='ese',
+
+    ),
+    ese_vovnet39b=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=5,
+        block_per_stage=[1, 1, 2, 2],
+        residual=True,
+        depthwise=False,
+        attn='ese',
+    ),
+    ese_vovnet57b=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=5,
+        block_per_stage=[1, 1, 4, 3],
+        residual=True,
+        depthwise=False,
+        attn='ese',
+
+    ),
+    ese_vovnet99b=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=5,
+        block_per_stage=[1, 3, 9, 3],
+        residual=True,
+        depthwise=False,
+        attn='ese',
+    ),
+    eca_vovnet39b=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=5,
+        block_per_stage=[1, 1, 2, 2],
+        residual=True,
+        depthwise=False,
+        attn='eca',
+    ),
+)
+model_cfgs['ese_vovnet39b_evos'] = model_cfgs['ese_vovnet39b']


def _create_vovnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
-        VovNet, variant, pretrained,
+        VovNet,
+        variant,
+        pretrained,
        model_cfg=model_cfgs[variant],
        feature_cfg=dict(flatten_sequential=True),
-        **kwargs)
+        **kwargs,
+    )


+def _cfg(url=''):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'vovnet39a.untrained': _cfg(url=''),
+    'vovnet57a.untrained': _cfg(url=''),
+    'ese_vovnet19b_slim_dw.untrained': _cfg(url=''),
+    'ese_vovnet19b_dw.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet19b_dw-a8741004.pth'),
+    'ese_vovnet19b_slim.untrained': _cfg(url=''),
+    'ese_vovnet39b.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet39b-f912fe73.pth'),
+    'ese_vovnet57b.untrained': _cfg(url=''),
+    'ese_vovnet99b.untrained': _cfg(url=''),
+    'eca_vovnet39b.untrained': _cfg(url=''),
+    'ese_vovnet39b_evos.untrained': _cfg(url=''),
+})


@register_model
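Since `.untrained` tags register architectures without weights, filtering the registry by pretrained availability separates the two. A quick sketch:

```python
import timm

# only tags that actually ship weights, e.g. ese_vovnet19b_dw.ra_in1k
print(timm.list_models('*vovnet*', pretrained=True))
```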
@@ -465,10 +466,3 @@ def ese_vovnet39b_evos(pretrained=False, **kwargs):
    def norm_act_fn(num_features, **nkwargs):
        return create_norm_act_layer('evonorms0', num_features, jit=False, **nkwargs)
    return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)
-
-
-@register_model
-def ese_vovnet99b_iabn(pretrained=False, **kwargs):
-    norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu')
-    return _create_vovnet(
-        'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs)
@ -15,47 +15,23 @@ from timm.layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act
|
||||
from timm.layers.helpers import to_3tuple
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
__all__ = ['XceptionAligned']
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10),
|
||||
'crop_pct': 0.903, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
}
default_cfgs = dict(
    xception41=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
    xception65=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65_ra3-1447db8d.pth',
        crop_pct=0.94,
    ),
    xception71=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),

    xception41p=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception41p_ra3-33195bc8.pth',
        crop_pct=0.94,
    ),
    xception65p=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65p_ra3-3c6114e4.pth',
        crop_pct=0.94,
    ),
)
class SeparableConv2d(nn.Module):
    def __init__(
            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=1, padding='',
            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
            self,
            in_chs,
            out_chs,
            kernel_size=3,
            stride=1,
            dilation=1,
            padding='',
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
    ):
        super(SeparableConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
@@ -84,8 +60,17 @@ class SeparableConv2d(nn.Module):

class PreSeparableConv2d(nn.Module):
    def __init__(
            self, in_chs, out_chs, kernel_size=3, stride=1, dilation=1, padding='',
            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, first_act=True):
            self,
            in_chs,
            out_chs,
            kernel_size=3,
            stride=1,
            dilation=1,
            padding='',
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            first_act=True,
    ):
        super(PreSeparableConv2d, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
        self.kernel_size = kernel_size
@@ -109,8 +94,17 @@ class PreSeparableConv2d(nn.Module):

class XceptionModule(nn.Module):
    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, pad_type='',
            start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None):
            self,
            in_chs,
            out_chs,
            stride=1,
            dilation=1,
            pad_type='',
            start_with_relu=True,
            no_skip=False,
            act_layer=nn.ReLU,
            norm_layer=None,
    ):
        super(XceptionModule, self).__init__()
        out_chs = to_3tuple(out_chs)
        self.in_channels = in_chs
@@ -144,8 +138,16 @@ class XceptionModule(nn.Module):

class PreXceptionModule(nn.Module):
    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, pad_type='',
            no_skip=False, act_layer=nn.ReLU, norm_layer=None):
            self,
            in_chs,
            out_chs,
            stride=1,
            dilation=1,
            pad_type='',
            no_skip=False,
            act_layer=nn.ReLU,
            norm_layer=None,
    ):
        super(PreXceptionModule, self).__init__()
        out_chs = to_3tuple(out_chs)
        self.in_channels = in_chs
@@ -160,8 +162,16 @@ class PreXceptionModule(nn.Module):
        self.stack = nn.Sequential()
        for i in range(3):
            self.stack.add_module(f'conv{i + 1}', PreSeparableConv2d(
                in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type,
                act_layer=act_layer, norm_layer=norm_layer, first_act=i > 0))
                in_chs,
                out_chs[i],
                3,
                stride=stride if i == 2 else 1,
                dilation=dilation,
                padding=pad_type,
                act_layer=act_layer,
                norm_layer=norm_layer,
                first_act=i > 0,
            ))
            in_chs = out_chs[i]

    def forward(self, x):
@@ -178,8 +188,17 @@ class XceptionAligned(nn.Module):
    """

    def __init__(
            self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, preact=False,
            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'):
            self,
            block_cfg,
            num_classes=1000,
            in_chans=3,
            output_stride=32,
            preact=False,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            drop_rate=0.,
            global_pool='avg',
    ):
        super(XceptionAligned, self).__init__()
        assert output_stride in (8, 16, 32)
        self.num_classes = num_classes
@@ -216,7 +235,11 @@ class XceptionAligned(nn.Module):
            num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
        self.act = act_layer(inplace=True) if preact else nn.Identity()
        self.head = ClassifierHead(
            in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
            in_features=self.num_features,
            num_classes=num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
        )

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
@@ -234,7 +257,7 @@ class XceptionAligned(nn.Module):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_features(self, x):
        x = self.stem(x)
@@ -256,9 +279,47 @@ class XceptionAligned(nn.Module):
def _xception(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        XceptionAligned, variant, pretrained,
        XceptionAligned,
        variant,
        pretrained,
        feature_cfg=dict(flatten_sequential=True, feature_cls='hook'),
        **kwargs)
        **kwargs,
    )
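The `feature_cls='hook'` setting means `features_only` variants of these models tap intermediate activations via forward hooks rather than by truncating the module tree. A usage sketch against the standard timm feature-extraction API:

```
import timm

m = timm.create_model('xception65', features_only=True, out_indices=(1, 2, 3))
print(m.feature_info.channels())  # channel counts of the selected taps
```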
def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10),
        'crop_pct': 0.903, 'interpolation': 'bicubic',
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
        **kwargs
    }
default_cfgs = generate_default_cfgs({
    'xception65.ra3_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65_ra3-1447db8d.pth',
        crop_pct=0.94,
    ),

    'xception41.tf_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
    'xception65.tf_in1k': _cfg(
        url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
    'xception71.tf_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),

    'xception41p.ra3_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception41p_ra3-33195bc8.pth',
        crop_pct=0.94,
    ),
    'xception65p.ra3_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65p_ra3-3c6114e4.pth',
        crop_pct=0.94,
    ),
})
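One architecture now carries several weight tags: `xception65`, for example, has both the `ra3_in1k` weights (listed first, so presumably the default, with `crop_pct=0.94`) and the TF-ported `tf_in1k` weights. A selection sketch:

```
import timm

m_default = timm.create_model('xception65', pretrained=True)     # first-listed tag
m_tf = timm.create_model('xception65.tf_in1k', pretrained=True)  # explicit tag
```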

@register_model