More models w/ multi-weight support, moving to HF hub. Removing inplace_abn from all models including TResNet

Ross Wightman 2023-04-20 22:41:39 -07:00
parent 2aabaef039
commit a08e5aed1d
14 changed files with 614 additions and 518 deletions
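The central change is the move to tagged, multi-weight pretrained configs: one architecture entry can now carry several weight sets, addressed as `model_name.tag`, with configs and weights hosted on the Hugging Face hub. A minimal usage sketch (the tag names below come from the cfgs in this diff; resolving a bare name to the first-listed tag is my reading of the registry):

```python
import timm

# An explicit tag selects a specific weight set for the architecture.
model = timm.create_model('tresnet_m.miil_in21k', pretrained=True)

# A bare name resolves to the architecture's default (first-listed) tag.
model = timm.create_model('tresnet_m', pretrained=True)
```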


@@ -48,7 +48,6 @@ jobs:
    - name: Install requirements
      run: |
        pip install -r requirements.txt
-       pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
    - name: Run tests on Windows
      if: startsWith(matrix.os, 'windows')
      env:


@@ -80,7 +80,6 @@ Then install the remaining dependencies:
```
python -m pip install -r requirements.txt
python -m pip install -r requirements-dev.txt  # for testing
-python -m pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
python -m pip install -e .
```


@@ -44,7 +44,7 @@ from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_e
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct
-from .space_to_depth import SpaceToDepthModule
+from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
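`SpaceToDepth` and its inverse `DepthToSpace` are now exported directly; the TResNet diff below swaps `SpaceToDepthModule` for `SpaceToDepth` in the stem. A quick shape check, assuming the default `block_size=4` implied by the `in_chans * 16` stem conv in that diff:

```python
import torch
from timm.layers import SpaceToDepth

x = torch.randn(2, 3, 224, 224)
y = SpaceToDepth()(x)   # rearranges each 4x4 spatial block into channels
print(y.shape)          # torch.Size([2, 48, 56, 56]): 3*16 channels, 224/4 spatial
```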


@@ -78,6 +78,7 @@ _ACT_LAYER_DEFAULT = dict(
    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
    hard_swish=nn.Hardswish if _has_hardswish else HardSwish,
    hard_mish=HardMish,
+   identity=nn.Identity,
)

_ACT_LAYER_JIT = dict(
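Registering `identity` as a named activation lets string-based configs request a no-op layer. A small check, assuming `get_act_layer` is re-exported from `timm.layers` as it is elsewhere in this commit:

```python
import torch.nn as nn
from timm.layers import get_act_layer

act_cls = get_act_layer('identity')   # resolves to nn.Identity via the entry above
assert act_cls is nn.Identity
```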


@@ -40,6 +40,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
            track_running_stats=True,
            apply_act=True,
            act_layer=nn.ReLU,
+           act_params=None,  # FIXME not the final approach
            inplace=True,
            drop_layer=None,
            device=None,
@@ -59,6 +60,8 @@ class BatchNormAct2d(nn.BatchNorm2d):
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
+           if act_params is not None:
+               act_args['negative_slope'] = act_params
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()
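`act_params` is threaded straight into the activation kwargs as `negative_slope`, so in this interim form it is only meaningful for LeakyReLU-style activations (hence the FIXME). A sketch of the intended use, consistent with the `negative_slope=1e-3` TResNet adopts below:

```python
import torch
from timm.layers import BatchNormAct2d

# act_layer may be a string; get_act_layer() resolves it (see the code above).
bn_act = BatchNormAct2d(64, act_layer='leaky_relu', act_params=1e-3)
y = bn_act(torch.randn(1, 64, 8, 8))   # BN followed by LeakyReLU(negative_slope=1e-3)
```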


@@ -17,7 +17,7 @@ class SpaceToDepth(nn.Module):

@torch.jit.script
-class SpaceToDepthJit(object):
+class SpaceToDepthJit:
    def __call__(self, x: torch.Tensor):
        # assuming hard-coded that block_size==4 for acceleration
        N, C, H, W = x.size()


@@ -314,7 +314,7 @@ def push_to_hf_hub(
def generate_readme(model_card: dict, model_name: str):
    readme_text = "---\n"
    readme_text += "tags:\n- image-classification\n- timm\n"
-   readme_text += "library_tag: timm\n"
+   readme_text += "library_name: timm\n"
    readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
    if 'details' in model_card and 'Dataset' in model_card['details']:
        readme_text += 'datasets:\n'
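`library_tag` was never a recognized model-card key; `library_name` is what the Hugging Face Hub reads to associate a model card with a library. A sketch of the new front matter, assuming the unshown remainder of the function tolerates a minimal `model_card` dict:

```python
from timm.models._hub import generate_readme  # internal helper patched above

card = {'license': 'apache-2.0'}  # hypothetical minimal model card
print(generate_readme(card, 'tresnet_m.miil_in1k').splitlines()[:6])
# ['---', 'tags:', '- image-classification', '- timm',
#  'library_name: timm', 'license: apache-2.0']
```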


@@ -14,76 +14,12 @@ Hacked together by / copyright Ross Wightman, 2021.
"""
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ._builder import build_model_with_cfg
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs
from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks

__all__ = []

-def _cfg(url='', **kwargs):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.95, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
-        'fixed_input_size': False, 'min_input_size': (3, 224, 224),
-        **kwargs
-    }
-
-default_cfgs = {
-    # GPU-Efficient (ResNet) weights
-    'botnet26t_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
-    'sebotnet33ts_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
-    'botnet50ts_256': _cfg(
-        url='',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
-    'eca_botnext26ts_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_botnext26ts_c_256-95a898f6.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
-    'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
-    'halonet26t': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_a1h_256-3083328c.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
-    'sehalonet33ts': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
-    'halonet50ts': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
-    'eca_halonext26ts': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
-    'lambda_resnet26t': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_c_256-e5a5c857.pth',
-        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
-    'lambda_resnet50ts': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet50ts_a1h_256-b87370f7.pth',
-        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
-    'lambda_resnet26rpt_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_c_256-ab00292d.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
-    'haloregnetz_b': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth',
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-        first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),
-    'lamhalobotnet50ts_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lamhalobotnet50ts_a1h2_256-fe3d9445.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
-    'halo2botnet50ts_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halo2botnet50ts_a1h2_256-fd9c11a3.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
-}

model_cfgs = dict(
    botnet26t=ByoModelCfg(
@@ -329,7 +265,71 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
        ByobNet, variant, pretrained,
        model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
        feature_cfg=dict(flatten_sequential=True),
-        **kwargs)
+        **kwargs,
+    )

+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.95, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
+        'fixed_input_size': False, 'min_input_size': (3, 224, 224),
+        **kwargs
+    }

+default_cfgs = generate_default_cfgs({
+    # GPU-Efficient (ResNet) weights
+    'botnet26t_256.c1_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+    'sebotnet33ts_256.a1h_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
+    'botnet50ts_256.untrained': _cfg(
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+    'eca_botnext26ts_256.c1_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_botnext26ts_c_256-95a898f6.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+    'halonet_h1.untrained': _cfg(input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+    'halonet26t.a1h_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_a1h_256-3083328c.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+    'sehalonet33ts.ra2_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
+    'halonet50ts.a1h_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
+    'eca_halonext26ts.c1_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
+    'lambda_resnet26t.c1_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_c_256-e5a5c857.pth',
+        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
+    'lambda_resnet50ts.a1h_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet50ts_a1h_256-b87370f7.pth',
+        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
+    'lambda_resnet26rpt_256.c1_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_c_256-ab00292d.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
+    'haloregnetz_b.ra3_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),
+    'lamhalobotnet50ts_256.a1h_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lamhalobotnet50ts_a1h2_256-fe3d9445.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+    'halo2botnet50ts_256.a1h_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halo2botnet50ts_a1h2_256-fd9c11a3.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+})

@register_model
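With `generate_default_cfgs`, the tag suffix becomes part of the registry key; `.untrained` entries carry no weights and exist for config purposes, and as I read the registry, a bare architecture name resolves to the first tag listed:

```python
import timm

m1 = timm.create_model('botnet26t_256', pretrained=True)             # -> .c1_in1k weights
m2 = timm.create_model('sebotnet33ts_256.a1h_in1k', pretrained=True)  # explicit tag
```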


@@ -15,39 +15,21 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
from ._builder import build_model_with_cfg
from ._manipulate import MATCH_PREV_GROUP
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs

__all__ = ['DenseNet']

-def _cfg(url=''):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'features.conv0', 'classifier': 'classifier',
-    }
-
-default_cfgs = {
-    'densenet121': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenet121_ra-50efcf5c.pth'),
-    'densenet121d': _cfg(url=''),
-    'densenetblur121d': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenetblur121d_ra-100dcfbc.pth'),
-    'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'),
-    'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'),
-    'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'),
-    'densenet264': _cfg(url=''),
-    'densenet264d_iabn': _cfg(url=''),
-    'tv_densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-a639ec97.pth'),
-}

class DenseLayer(nn.Module):
    def __init__(
-            self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
-            drop_rate=0., memory_efficient=False):
+            self,
+            num_input_features,
+            growth_rate,
+            bn_size,
+            norm_layer=BatchNormAct2d,
+            drop_rate=0.,
+            memory_efficient=False,
+    ):
        super(DenseLayer, self).__init__()
        self.add_module('norm1', norm_layer(num_input_features)),
        self.add_module('conv1', nn.Conv2d(
@@ -145,7 +127,13 @@ class DenseBlock(nn.ModuleDict):

class DenseTransition(nn.Sequential):
-    def __init__(self, num_input_features, num_output_features, norm_layer=BatchNormAct2d, aa_layer=None):
+    def __init__(
+            self,
+            num_input_features,
+            num_output_features,
+            norm_layer=BatchNormAct2d,
+            aa_layer=None,
+    ):
        super(DenseTransition, self).__init__()
        self.add_module('norm', norm_layer(num_input_features))
        self.add_module('conv', nn.Conv2d(
@@ -324,9 +312,35 @@ def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs):
    kwargs['growth_rate'] = growth_rate
    kwargs['block_config'] = block_config
    return build_model_with_cfg(
-        DenseNet, variant, pretrained,
+        DenseNet,
+        variant,
+        pretrained,
        feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained,
-        **kwargs)
+        **kwargs,
+    )

+def _cfg(url=''):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'features.conv0', 'classifier': 'classifier',
+    }

+default_cfgs = {
+    'densenet121.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenet121_ra-50efcf5c.pth'),
+    'densenet121d': _cfg(url=''),
+    'densenetblur121d.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenetblur121d_ra-100dcfbc.pth'),
+    'densenet169.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'),
+    'densenet201.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'),
+    'densenet161.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'),
+    'densenet264.untrained': _cfg(url=''),
+    'densenet121.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet121-a639ec97.pth'),
+}

@register_model
@@ -400,22 +414,3 @@ def densenet264(pretrained=False, **kwargs):
        'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs)
    return model

-@register_model
-def densenet264d_iabn(pretrained=False, **kwargs):
-    r"""Densenet-264 model with deep stem and Inplace-ABN
-    """
-    model = _create_densenet(
-        'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep',
-        norm_layer='iabn', act_layer='leaky_relu', pretrained=pretrained, **kwargs)
-    return model
-
-@register_model
-def tv_densenet121(pretrained=False, **kwargs):
-    r"""Densenet-121 model with original Torchvision weights, from
-    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
-    """
-    model = _create_densenet(
-        'tv_densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs)
-    return model


@@ -17,39 +17,11 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier, get_norm_act_layer
from ._builder import build_model_with_cfg
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs

__all__ = ['DPN']

-def _cfg(url='', **kwargs):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD,
-        'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier',
-        **kwargs
-    }
-
-default_cfgs = {
-    'dpn48b': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    'dpn68': _cfg(
-        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
-    'dpn68b': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dpn68b_ra-a31ca160.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    'dpn92': _cfg(
-        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'),
-    'dpn98': _cfg(
-        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'),
-    'dpn131': _cfg(
-        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth'),
-    'dpn107': _cfg(
-        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth')
-}

class CatBnAct(nn.Module):
    def __init__(self, in_chs, norm_layer=BatchNormAct2d):
        super(CatBnAct, self).__init__()
@@ -310,9 +282,42 @@ class DPN(nn.Module):

def _create_dpn(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
-        DPN, variant, pretrained,
+        DPN,
+        variant,
+        pretrained,
        feature_cfg=dict(feature_concat=True, flatten_sequential=True),
-        **kwargs)
+        **kwargs,
+    )

+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD,
+        'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier',
+        **kwargs
+    }

+default_cfgs = generate_default_cfgs({
+    'dpn48b.untrained': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'dpn68.mx_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
+    'dpn68b.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dpn68b_ra-a31ca160.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'dpn68b.mx_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth'),
+    'dpn92.mx_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'),
+    'dpn98.mx_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'),
+    'dpn131.mx_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth'),
+    'dpn107.mx_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth')
+})

@register_model
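`dpn68b` is the first entry here carrying two weight sets under one name: timm's RandAugment recipe and the original MXNet port (newly listed as `dpn68b.mx_in1k`). Selecting between them:

```python
import timm

ra = timm.create_model('dpn68b.ra_in1k', pretrained=True)  # timm RA recipe, listed first
mx = timm.create_model('dpn68b.mx_in1k', pretrained=True)  # original MXNet weights
```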


@@ -7,32 +7,12 @@ from ._builder import build_model_with_cfg
from ._builder import pretrained_cfg_for_features
from ._efficientnet_blocks import SqueezeExcite
from ._efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs
from .mobilenetv3 import MobileNetV3, MobileNetV3Features

__all__ = []  # model_registry will add each entrypoint fn to this

-def _cfg(url='', **kwargs):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'conv_stem', 'classifier': 'classifier',
-        **kwargs
-    }
-
-default_cfgs = {
-    'hardcorenas_a': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_a_green_38ms_75_9-31dc7186.pth'),
-    'hardcorenas_b': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_b_green_40ms_76_5-32d91ff2.pth'),
-    'hardcorenas_c': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_c_green_44ms_77_1-631a0983.pth'),
-    'hardcorenas_d': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_d_green_50ms_77_4-998d9d7a.pth'),
-    'hardcorenas_e': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_e_green_55ms_77_9-482886a3.pth'),
-    'hardcorenas_f': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_f_green_60ms_78_1-14b9e780.pth'),
-}

def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
    """Creates a hardcorenas model
@@ -60,15 +40,44 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
        kwargs_filter = ('num_classes', 'num_features', 'global_pool', 'head_conv', 'head_bias', 'global_pool')
        model_cls = MobileNetV3Features
    model = build_model_with_cfg(
-        model_cls, variant, pretrained,
+        model_cls,
+        variant,
+        pretrained,
        pretrained_strict=not features_only,
        kwargs_filter=kwargs_filter,
-        **model_kwargs)
+        **model_kwargs,
+    )
    if features_only:
        model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
    return model

+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'conv_stem', 'classifier': 'classifier',
+        **kwargs
+    }

+default_cfgs = generate_default_cfgs({
+    'hardcorenas_a.green_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_a_green_38ms_75_9-31dc7186.pth'),
+    'hardcorenas_b.green_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_b_green_40ms_76_5-32d91ff2.pth'),
+    'hardcorenas_c.green_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_c_green_44ms_77_1-631a0983.pth'),
+    'hardcorenas_d.green_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_d_green_50ms_77_4-998d9d7a.pth'),
+    'hardcorenas_e.green_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_e_green_55ms_77_9-482886a3.pth'),
+    'hardcorenas_f.green_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_f_green_60ms_78_1-14b9e780.pth'),
+})

@register_model
def hardcorenas_a(pretrained=False, **kwargs):
    """ hardcorenas_A """


@@ -6,104 +6,62 @@ Original model: https://github.com/mrT23/TResNet
"""
from collections import OrderedDict
+from functools import partial

import torch
import torch.nn as nn

-from timm.layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule
+from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, \
+    ConvNormActAa, ConvNormAct, DropPath
from ._builder import build_model_with_cfg
-from ._registry import register_model
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs, register_model_deprecations

__all__ = ['TResNet']  # model_registry will add each entrypoint fn to this

-def _cfg(url='', **kwargs):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
-        'mean': (0., 0., 0.), 'std': (1., 1., 1.),
-        'first_conv': 'body.conv1.0', 'classifier': 'head.fc',
-        **kwargs
-    }
-
-default_cfgs = {
-    'tresnet_m': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_1k_miil_83_1-d236afcb.pth'),
-    'tresnet_m_miil_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_miil_in21k-901b6ed4.pth', num_classes=11221),
-    'tresnet_l': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth'),
-    'tresnet_xl': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'),
-    'tresnet_m_448': _cfg(
-        input_size=(3, 448, 448), pool_size=(14, 14),
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'),
-    'tresnet_l_448': _cfg(
-        input_size=(3, 448, 448), pool_size=(14, 14),
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'),
-    'tresnet_xl_448': _cfg(
-        input_size=(3, 448, 448), pool_size=(14, 14),
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth'),
-    'tresnet_v2_l': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_v2_83_9-f36e4445.pth'),
-}
-
-def IABN2Float(module: nn.Module) -> nn.Module:
-    """If `module` is IABN don't use half precision."""
-    if isinstance(module, InplaceAbn):
-        module.float()
-    for child in module.children():
-        IABN2Float(child)
-    return module
-
-def conv2d_iabn(ni, nf, stride, kernel_size=3, groups=1, act_layer="leaky_relu", act_param=1e-2):
-    return nn.Sequential(
-        nn.Conv2d(
-            ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups, bias=False),
-        InplaceAbn(nf, act_layer=act_layer, act_param=act_param)
-    )

class BasicBlock(nn.Module):
    expansion = 1

-    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, aa_layer=None):
+    def __init__(
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            use_se=True,
+            aa_layer=None,
+            drop_path_rate=0.
+    ):
        super(BasicBlock, self).__init__()
-        if stride == 1:
-            self.conv1 = conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3)
-        else:
-            if aa_layer is None:
-                self.conv1 = conv2d_iabn(inplanes, planes, stride=2, act_param=1e-3)
-            else:
-                self.conv1 = nn.Sequential(
-                    conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3),
-                    aa_layer(channels=planes, filt_size=3, stride=2))
-        self.conv2 = conv2d_iabn(planes, planes, stride=1, act_layer="identity")
-        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
+        act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)
+        if stride == 1:
+            self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=1, act_layer=act_layer)
+        else:
+            self.conv1 = ConvNormActAa(
+                inplanes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
+        self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, act_layer=None)
+        self.act = nn.ReLU(inplace=True)
        rd_chs = max(planes * self.expansion // 4, 64)
        self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x):
        if self.downsample is not None:
            shortcut = self.downsample(x)
        else:
            shortcut = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.se is not None:
            out = self.se(out)
-        out = out + shortcut
-        out = self.relu(out)
+        out = self.drop_path(out) + shortcut
+        out = self.act(out)
        return out

@@ -111,47 +69,51 @@
    expansion = 4

    def __init__(
-            self, inplanes, planes, stride=1, downsample=None, use_se=True,
-            act_layer="leaky_relu", aa_layer=None):
+            self,
+            inplanes,
+            planes,
+            stride=1,
+            downsample=None,
+            use_se=True,
+            act_layer=None,
+            aa_layer=None,
+            drop_path_rate=0.,
+    ):
        super(Bottleneck, self).__init__()
-        self.conv1 = conv2d_iabn(
-            inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, act_param=1e-3)
+        self.downsample = downsample
+        self.stride = stride
+        act_layer = act_layer or partial(nn.LeakyReLU, negative_slope=1e-3)
+        self.conv1 = ConvNormAct(
+            inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer)
        if stride == 1:
-            self.conv2 = conv2d_iabn(
-                planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3)
+            self.conv2 = ConvNormAct(
+                planes, planes, kernel_size=3, stride=1, act_layer=act_layer)
        else:
-            if aa_layer is None:
-                self.conv2 = conv2d_iabn(
-                    planes, planes, kernel_size=3, stride=2, act_layer=act_layer, act_param=1e-3)
-            else:
-                self.conv2 = nn.Sequential(
-                    conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
-                    aa_layer(channels=planes, filt_size=3, stride=2))
+            self.conv2 = ConvNormActAa(
+                planes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)

        reduction_chs = max(planes * self.expansion // 8, 64)
        self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None
-        self.conv3 = conv2d_iabn(
-            planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
+        self.conv3 = ConvNormAct(
+            planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.act = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride

    def forward(self, x):
        if self.downsample is not None:
            shortcut = self.downsample(x)
        else:
            shortcut = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.se is not None:
            out = self.se(out)
        out = self.conv3(out)
-        out = out + shortcut  # no inplace
+        out = self.drop_path(out) + shortcut
        out = self.act(out)
        return out

@@ -165,12 +127,15 @@
            v2=False,
            global_pool='fast',
            drop_rate=0.,
+            drop_path_rate=0.,
    ):
        self.num_classes = num_classes
        self.drop_rate = drop_rate
+        self.grad_checkpointing = False
        super(TResNet, self).__init__()

        aa_layer = BlurPool2d
+        act_layer = nn.LeakyReLU

        # TResnet stages
        self.inplanes = int(64 * width_factor)
@@ -179,24 +144,30 @@
        self.inplanes = self.inplanes // 8 * 8
        self.planes = self.planes // 8 * 8

-        conv1 = conv2d_iabn(in_chans * 16, self.planes, stride=1, kernel_size=3)
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
+        conv1 = ConvNormAct(in_chans * 16, self.planes, stride=1, kernel_size=3, act_layer=act_layer)
        layer1 = self._make_layer(
-            Bottleneck if v2 else BasicBlock, self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer)
+            Bottleneck if v2 else BasicBlock,
+            self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[0])
        layer2 = self._make_layer(
-            Bottleneck if v2 else BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer)
+            Bottleneck if v2 else BasicBlock,
+            self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[1])
        layer3 = self._make_layer(
-            Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer)
+            Bottleneck,
+            self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[2])
        layer4 = self._make_layer(
-            Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer)
+            Bottleneck,
+            self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer, drop_path_rate=dpr[3])

        # body
        self.body = nn.Sequential(OrderedDict([
-            ('SpaceToDepth', SpaceToDepthModule()),
+            ('s2d', SpaceToDepth()),
            ('conv1', conv1),
            ('layer1', layer1),
            ('layer2', layer2),
            ('layer3', layer3),
-            ('layer4', layer4)]))
+            ('layer4', layer4),
+        ]))

        self.feature_info = [
            dict(num_chs=self.planes, reduction=2, module=''),  # Not with S2D?
@@ -214,37 +185,39 @@
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
-            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InplaceAbn):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
+            if isinstance(m, nn.Linear):
+                m.weight.data.normal_(0, 0.01)

        # residual connections special initialization
        for m in self.modules():
            if isinstance(m, BasicBlock):
-                m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight))  # BN to zero
+                nn.init.zeros_(m.conv2.bn.weight)
            if isinstance(m, Bottleneck):
-                m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight))  # BN to zero
+                nn.init.zeros_(m.conv3.bn.weight)
-            if isinstance(m, nn.Linear):
-                m.weight.data.normal_(0, 0.01)

-    def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None):
+    def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None, drop_path_rate=0.):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            layers = []
            if stride == 2:
                # avg pooling before 1x1 conv
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
-            layers += [conv2d_iabn(
-                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, act_layer="identity")]
+            layers += [ConvNormAct(
+                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)]
            downsample = nn.Sequential(*layers)

        layers = []
+        for i in range(blocks):
            layers.append(block(
-                self.inplanes, planes, stride, downsample, use_se=use_se, aa_layer=aa_layer))
+                self.inplanes,
+                planes,
+                stride=stride if i == 0 else 1,
+                downsample=downsample if i == 0 else None,
+                use_se=use_se,
+                aa_layer=aa_layer,
+                drop_path_rate=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate,
+            ))
            self.inplanes = planes * block.expansion
-        for i in range(1, blocks):
-            layers.append(
-                block(self.inplanes, planes, use_se=use_se, aa_layer=aa_layer))
        return nn.Sequential(*layers)

    @torch.jit.ignore
@@ -254,18 +227,28 @@
    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
-        assert not enable, 'gradient checkpointing not supported'
+        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

-    def reset_classifier(self, num_classes, global_pool='fast'):
-        self.head = ClassifierHead(
-            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.head.reset(num_classes, pool_type=global_pool)

    def forward_features(self, x):
-        return self.body(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = self.body.s2d(x)
+            x = self.body.conv1(x)
+            x = checkpoint_seq([
+                self.body.layer1,
+                self.body.layer2,
+                self.body.layer3,
+                self.body.layer4],
+                x, flatten=True)
+        else:
+            x = self.body(x)
+        return x

    def forward_head(self, x, pre_logits: bool = False):
        return x if pre_logits else self.head(x)
@@ -276,11 +259,74 @@
        return x

+def checkpoint_filter_fn(state_dict, model):
+    if 'body.conv1.conv.weight' in state_dict:
+        return state_dict
+
+    import re
+    state_dict = state_dict.get('model', state_dict)
+    state_dict = state_dict.get('state_dict', state_dict)
+    out_dict = {}
+    for k, v in state_dict.items():
+        k = re.sub(r'conv(\d+)\.0.0', lambda x: f'conv{int(x.group(1))}.conv', k)
+        k = re.sub(r'conv(\d+)\.0.1', lambda x: f'conv{int(x.group(1))}.bn', k)
+        k = re.sub(r'conv(\d+)\.0', lambda x: f'conv{int(x.group(1))}.conv', k)
+        k = re.sub(r'conv(\d+)\.1', lambda x: f'conv{int(x.group(1))}.bn', k)
+        k = k.replace('downsample.1.0', 'downsample.1.conv')
+        k = k.replace('downsample.1.1', 'downsample.1.bn')
+        if k.endswith('bn.weight'):
+            # convert weight from inplace_abn to batchnorm
+            v = v.abs().add(1e-5)
+        out_dict[k] = v
+    return out_dict

def _create_tresnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
-        TResNet, variant, pretrained,
+        TResNet,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True),
-        **kwargs)
+        **kwargs,
+    )

+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': (0., 0., 0.), 'std': (1., 1., 1.),
+        'first_conv': 'body.conv1.conv', 'classifier': 'head.fc',
+        **kwargs
+    }

+default_cfgs = generate_default_cfgs({
+    'tresnet_m.miil_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_1k_miil_83_1-d236afcb.pth'),
+    'tresnet_m.miil_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_miil_in21k-901b6ed4.pth',
+        num_classes=11221),
+    'tresnet_m.miil_in1k': _cfg(
+        url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_80_8-dbc13962.pth'),
+    'tresnet_l.miil_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth'),
+    'tresnet_xl.miil_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'),
+    'tresnet_m.miil_in1k_448': _cfg(
+        input_size=(3, 448, 448), pool_size=(14, 14),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'),
+    'tresnet_l.miil_in1k_448': _cfg(
+        input_size=(3, 448, 448), pool_size=(14, 14),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'),
+    'tresnet_xl.miil_in1k_448': _cfg(
+        input_size=(3, 448, 448), pool_size=(14, 14),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth'),
+    'tresnet_v2_l.miil_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_v2_83_9-f36e4445.pth'),
+})

@register_model
@@ -289,24 +335,12 @@ def tresnet_m(pretrained=False, **kwargs):
    return _create_tresnet('tresnet_m', pretrained=pretrained, **model_kwargs)

-@register_model
-def tresnet_m_miil_in21k(pretrained=False, **kwargs):
-    model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs)
-    return _create_tresnet('tresnet_m_miil_in21k', pretrained=pretrained, **model_kwargs)
-
@register_model
def tresnet_l(pretrained=False, **kwargs):
    model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs)
    return _create_tresnet('tresnet_l', pretrained=pretrained, **model_kwargs)

-@register_model
-def tresnet_v2_l(pretrained=False, **kwargs):
-    model_kwargs = dict(layers=[3, 4, 23, 3], width_factor=1.0, v2=True, **kwargs)
-    return _create_tresnet('tresnet_v2_l', pretrained=pretrained, **model_kwargs)
-
@register_model
def tresnet_xl(pretrained=False, **kwargs):
    model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs)
@@ -314,18 +348,14 @@ def tresnet_xl(pretrained=False, **kwargs):

@register_model
-def tresnet_m_448(pretrained=False, **kwargs):
-    model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs)
-    return _create_tresnet('tresnet_m_448', pretrained=pretrained, **model_kwargs)
+def tresnet_v2_l(pretrained=False, **kwargs):
+    model_kwargs = dict(layers=[3, 4, 23, 3], width_factor=1.0, v2=True, **kwargs)
+    return _create_tresnet('tresnet_v2_l', pretrained=pretrained, **model_kwargs)

-@register_model
-def tresnet_l_448(pretrained=False, **kwargs):
-    model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs)
-    return _create_tresnet('tresnet_l_448', pretrained=pretrained, **model_kwargs)
-
-@register_model
-def tresnet_xl_448(pretrained=False, **kwargs):
-    model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs)
-    return _create_tresnet('tresnet_xl_448', pretrained=pretrained, **model_kwargs)
+register_model_deprecations(__name__, {
+    'tresnet_m_miil_in21k': 'tresnet_m.miil_in21k',
+    'tresnet_m_448': 'tresnet_m.miil_in1k_448',
+    'tresnet_l_448': 'tresnet_l.miil_in1k_448',
+    'tresnet_xl_448': 'tresnet_xl.miil_in1k_448',
+})
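Because TResNet now builds on plain `BatchNorm2d` via `ConvNormAct`, old Inplace-ABN checkpoints must be remapped on load: `checkpoint_filter_fn` renames the old sequential indices to `conv`/`bn` and converts the norm scale. My reading of the `abs().add(1e-5)` step: Inplace-ABN stores gamma with free sign (the sign encodes the activation branch), so its effective scale is `|gamma|`, and the small eps keeps it strictly positive for BN. A toy illustration of just the weight conversion:

```python
import torch

gamma_iabn = torch.tensor([-0.31, 0.02, 0.77])  # signed, as stored by inplace_abn
gamma_bn = gamma_iabn.abs().add(1e-5)           # positive scale for plain BatchNorm2d
```

The old `_448` and `_miil_in21k` entrypoints are not deleted silently: `register_model_deprecations` maps them to the new tagged names, so existing code keeps working with a deprecation warning.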


@@ -18,152 +18,14 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \
-    create_attn, create_norm_act_layer, get_norm_act_layer
+    create_attn, create_norm_act_layer
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs

__all__ = ['VovNet']  # model_registry will add each entrypoint fn to this

-# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
-# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
-model_cfgs = dict(
-    vovnet39a=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=5,
-        block_per_stage=[1, 1, 2, 2],
-        residual=False,
-        depthwise=False,
-        attn='',
-    ),
-    vovnet57a=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=5,
-        block_per_stage=[1, 1, 4, 3],
-        residual=False,
-        depthwise=False,
-        attn='',
-    ),
-    ese_vovnet19b_slim_dw=dict(
-        stem_chs=[64, 64, 64],
-        stage_conv_chs=[64, 80, 96, 112],
-        stage_out_chs=[112, 256, 384, 512],
-        layer_per_block=3,
-        block_per_stage=[1, 1, 1, 1],
-        residual=True,
-        depthwise=True,
-        attn='ese',
-    ),
-    ese_vovnet19b_dw=dict(
-        stem_chs=[64, 64, 64],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=3,
-        block_per_stage=[1, 1, 1, 1],
-        residual=True,
-        depthwise=True,
-        attn='ese',
-    ),
-    ese_vovnet19b_slim=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[64, 80, 96, 112],
-        stage_out_chs=[112, 256, 384, 512],
-        layer_per_block=3,
-        block_per_stage=[1, 1, 1, 1],
-        residual=True,
-        depthwise=False,
-        attn='ese',
-    ),
-    ese_vovnet19b=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=3,
-        block_per_stage=[1, 1, 1, 1],
-        residual=True,
-        depthwise=False,
-        attn='ese',
-    ),
-    ese_vovnet39b=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=5,
-        block_per_stage=[1, 1, 2, 2],
-        residual=True,
-        depthwise=False,
-        attn='ese',
-    ),
-    ese_vovnet57b=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=5,
-        block_per_stage=[1, 1, 4, 3],
-        residual=True,
-        depthwise=False,
-        attn='ese',
-    ),
-    ese_vovnet99b=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=5,
-        block_per_stage=[1, 3, 9, 3],
-        residual=True,
-        depthwise=False,
-        attn='ese',
-    ),
-    eca_vovnet39b=dict(
-        stem_chs=[64, 64, 128],
-        stage_conv_chs=[128, 160, 192, 224],
-        stage_out_chs=[256, 512, 768, 1024],
-        layer_per_block=5,
-        block_per_stage=[1, 1, 2, 2],
-        residual=True,
-        depthwise=False,
-        attn='eca',
-    ),
-)
-model_cfgs['ese_vovnet39b_evos'] = model_cfgs['ese_vovnet39b']
-model_cfgs['ese_vovnet99b_iabn'] = model_cfgs['ese_vovnet99b']
-
-def _cfg(url=''):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
-    }
-
-default_cfgs = dict(
-    vovnet39a=_cfg(url=''),
-    vovnet57a=_cfg(url=''),
-    ese_vovnet19b_slim_dw=_cfg(url=''),
-    ese_vovnet19b_dw=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet19b_dw-a8741004.pth'),
-    ese_vovnet19b_slim=_cfg(url=''),
-    ese_vovnet39b=_cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet39b-f912fe73.pth'),
-    ese_vovnet57b=_cfg(url=''),
-    ese_vovnet99b=_cfg(url=''),
-    eca_vovnet39b=_cfg(url=''),
-    ese_vovnet39b_evos=_cfg(url=''),
-    ese_vovnet99b_iabn=_cfg(url=''),
-)

class SequentialAppendList(nn.Sequential):
    def __init__(self, *args):
        super(SequentialAppendList, self).__init__(*args)
@@ -405,12 +267,151 @@
        return x

+# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
+# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
+model_cfgs = dict(
+    vovnet39a=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=5,
+        block_per_stage=[1, 1, 2, 2],
+        residual=False,
+        depthwise=False,
+        attn='',
+    ),
+    vovnet57a=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=5,
+        block_per_stage=[1, 1, 4, 3],
+        residual=False,
+        depthwise=False,
+        attn='',
+    ),
+    ese_vovnet19b_slim_dw=dict(
+        stem_chs=[64, 64, 64],
+        stage_conv_chs=[64, 80, 96, 112],
+        stage_out_chs=[112, 256, 384, 512],
+        layer_per_block=3,
+        block_per_stage=[1, 1, 1, 1],
+        residual=True,
+        depthwise=True,
+        attn='ese',
+    ),
+    ese_vovnet19b_dw=dict(
+        stem_chs=[64, 64, 64],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=3,
+        block_per_stage=[1, 1, 1, 1],
+        residual=True,
+        depthwise=True,
+        attn='ese',
+    ),
+    ese_vovnet19b_slim=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[64, 80, 96, 112],
+        stage_out_chs=[112, 256, 384, 512],
+        layer_per_block=3,
+        block_per_stage=[1, 1, 1, 1],
+        residual=True,
+        depthwise=False,
+        attn='ese',
+    ),
+    ese_vovnet19b=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=3,
+        block_per_stage=[1, 1, 1, 1],
+        residual=True,
+        depthwise=False,
+        attn='ese',
+    ),
+    ese_vovnet39b=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=5,
+        block_per_stage=[1, 1, 2, 2],
+        residual=True,
+        depthwise=False,
+        attn='ese',
+    ),
+    ese_vovnet57b=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=5,
+        block_per_stage=[1, 1, 4, 3],
+        residual=True,
+        depthwise=False,
+        attn='ese',
+    ),
+    ese_vovnet99b=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=5,
+        block_per_stage=[1, 3, 9, 3],
+        residual=True,
+        depthwise=False,
+        attn='ese',
+    ),
+    eca_vovnet39b=dict(
+        stem_chs=[64, 64, 128],
+        stage_conv_chs=[128, 160, 192, 224],
+        stage_out_chs=[256, 512, 768, 1024],
+        layer_per_block=5,
+        block_per_stage=[1, 1, 2, 2],
+        residual=True,
+        depthwise=False,
+        attn='eca',
+    ),
+)
+model_cfgs['ese_vovnet39b_evos'] = model_cfgs['ese_vovnet39b']

def _create_vovnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
-        VovNet, variant, pretrained,
+        VovNet,
+        variant,
+        pretrained,
        model_cfg=model_cfgs[variant],
        feature_cfg=dict(flatten_sequential=True),
-        **kwargs)
+        **kwargs,
+    )

+def _cfg(url=''):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
+    }

+default_cfgs = generate_default_cfgs({
+    'vovnet39a.untrained': _cfg(url=''),
+    'vovnet57a.untrained': _cfg(url=''),
+    'ese_vovnet19b_slim_dw.untrained': _cfg(url=''),
+    'ese_vovnet19b_dw.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet19b_dw-a8741004.pth'),
+    'ese_vovnet19b_slim.untrained': _cfg(url=''),
+    'ese_vovnet39b.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet39b-f912fe73.pth'),
+    'ese_vovnet57b.untrained': _cfg(url=''),
+    'ese_vovnet99b.untrained': _cfg(url=''),
+    'eca_vovnet39b.untrained': _cfg(url=''),
+    'ese_vovnet39b_evos.untrained': _cfg(url=''),
+})

@register_model
@@ -465,10 +466,3 @@ def ese_vovnet39b_evos(pretrained=False, **kwargs):
    def norm_act_fn(num_features, **nkwargs):
        return create_norm_act_layer('evonorms0', num_features, jit=False, **nkwargs)
    return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)
-
-@register_model
-def ese_vovnet99b_iabn(pretrained=False, **kwargs):
-    norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu')
-    return _create_vovnet(
-        'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs)


@ -15,47 +15,23 @@ from timm.layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act
from timm.layers.helpers import to_3tuple from timm.layers.helpers import to_3tuple
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq from ._manipulate import checkpoint_seq
from ._registry import register_model from ._registry import register_model, generate_default_cfgs
__all__ = ['XceptionAligned'] __all__ = ['XceptionAligned']
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10),
'crop_pct': 0.903, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = dict(
xception41=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
xception65=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65_ra3-1447db8d.pth',
crop_pct=0.94,
),
xception71=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
xception41p=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception41p_ra3-33195bc8.pth',
crop_pct=0.94,
),
xception65p=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65p_ra3-3c6114e4.pth',
crop_pct=0.94,
),
)
class SeparableConv2d(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size=3,
            stride=1,
            dilation=1,
            padding='',
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
    ):
        super(SeparableConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
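
This SeparableConv2d is a depthwise 3x3 followed by a pointwise 1x1, each with its own norm and optional activation. A usage sketch, assuming the class remains importable from timm.models.xception_aligned:

    import torch
    from timm.models.xception_aligned import SeparableConv2d

    conv = SeparableConv2d(32, 64, kernel_size=3, stride=2)
    print(conv(torch.randn(1, 32, 64, 64)).shape)  # torch.Size([1, 64, 32, 32])
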
@@ -84,8 +60,17 @@ class SeparableConv2d(nn.Module):
class PreSeparableConv2d(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size=3,
            stride=1,
            dilation=1,
            padding='',
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            first_act=True,
    ):
        super(PreSeparableConv2d, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
        self.kernel_size = kernel_size
@@ -109,8 +94,17 @@ class PreSeparableConv2d(nn.Module):
class XceptionModule(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            stride=1,
            dilation=1,
            pad_type='',
            start_with_relu=True,
            no_skip=False,
            act_layer=nn.ReLU,
            norm_layer=None,
    ):
        super(XceptionModule, self).__init__()
        out_chs = to_3tuple(out_chs)
        self.in_channels = in_chs
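
out_chs is normalized with to_3tuple, so a scalar channel spec expands to the three separable convs each Xception module stacks:

    from timm.layers.helpers import to_3tuple

    print(to_3tuple(728))              # (728, 728, 728)
    print(to_3tuple((256, 256, 728)))  # already a 3-tuple, passed through
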
@@ -144,8 +138,16 @@ class XceptionModule(nn.Module):
class PreXceptionModule(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            stride=1,
            dilation=1,
            pad_type='',
            no_skip=False,
            act_layer=nn.ReLU,
            norm_layer=None,
    ):
        super(PreXceptionModule, self).__init__()
        out_chs = to_3tuple(out_chs)
        self.in_channels = in_chs
@@ -160,8 +162,16 @@ class PreXceptionModule(nn.Module):
        self.stack = nn.Sequential()
        for i in range(3):
            self.stack.add_module(f'conv{i + 1}', PreSeparableConv2d(
                in_chs,
                out_chs[i],
                3,
                stride=stride if i == 2 else 1,
                dilation=dilation,
                padding=pad_type,
                act_layer=act_layer,
                norm_layer=norm_layer,
                first_act=i > 0,
            ))
            in_chs = out_chs[i]

    def forward(self, x):
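
Only the third conv in the stack carries the module's stride, and conv1 skips its pre-activation (first_act is False for i == 0). A plain shape-flow sketch with made-up example channels:

    in_chs, out_chs, stride = 128, (256, 256, 256), 2
    for i in range(3):
        s = stride if i == 2 else 1
        print(f'conv{i + 1}: {in_chs} -> {out_chs[i]}, stride={s}, first_act={i > 0}')
        in_chs = out_chs[i]
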
@@ -178,8 +188,17 @@ class XceptionAligned(nn.Module):
    """

    def __init__(
            self,
            block_cfg,
            num_classes=1000,
            in_chans=3,
            output_stride=32,
            preact=False,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            drop_rate=0.,
            global_pool='avg',
    ):
        super(XceptionAligned, self).__init__()
        assert output_stride in (8, 16, 32)
        self.num_classes = num_classes
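
output_stride trades stride for dilation so the deepest features can stay at 1/8 or 1/16 resolution for dense prediction. A hedged sketch, assuming create_model forwards the kwarg to the constructor as usual:

    import timm
    import torch

    model = timm.create_model('xception65', features_only=True, output_stride=16)
    for f in model(torch.randn(1, 3, 299, 299)):
        print(f.shape)  # deepest map at roughly input/16 instead of input/32
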
@@ -216,7 +235,11 @@ class XceptionAligned(nn.Module):
            num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
        self.act = act_layer(inplace=True) if preact else nn.Identity()
        self.head = ClassifierHead(
            in_features=self.num_features,
            num_classes=num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
        )

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
@@ -234,7 +257,7 @@ class XceptionAligned(nn.Module):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.head.reset(num_classes, pool_type=global_pool)
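
reset_classifier now resets the existing ClassifierHead in place rather than constructing a new one, so construction-time settings such as drop_rate survive. A small sketch:

    import timm

    model = timm.create_model('xception41', num_classes=1000)
    model.reset_classifier(10)     # swaps only the final fc
    print(model.get_classifier())  # Linear(..., out_features=10, ...)
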
    def forward_features(self, x):
        x = self.stem(x)
@@ -256,9 +279,47 @@
def _xception(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        XceptionAligned,
        variant,
        pretrained,
        feature_cfg=dict(flatten_sequential=True, feature_cls='hook'),
        **kwargs,
    )
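
feature_cfg here sets feature_cls='hook', so features_only models tap intermediate outputs via forward hooks rather than slicing the module graph. A sketch of the resulting extractor:

    import timm
    import torch

    model = timm.create_model('xception41', features_only=True)
    print(model.feature_info.reduction())  # e.g. [2, 4, 8, 16, 32]
    print([o.shape for o in model(torch.randn(1, 3, 299, 299))])
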
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10),
'crop_pct': 0.903, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = generate_default_cfgs({
'xception65.ra3_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65_ra3-1447db8d.pth',
crop_pct=0.94,
),
'xception41.tf_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
'xception65.tf_in1k': _cfg(
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
'xception71.tf_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
'xception41p.ra3_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception41p_ra3-33195bc8.pth',
crop_pct=0.94,
),
'xception65p.ra3_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65p_ra3-3c6114e4.pth',
crop_pct=0.94,
),
})
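
xception65 now carries two weight tags, with ra3_in1k listed first as the architecture default. A hedged sketch of picking between them (the glob-style list_pretrained filter is an assumption here):

    import timm

    print(timm.list_pretrained('xception65*'))  # assumed wildcard filtering
    ra3 = timm.create_model('xception65.ra3_in1k', pretrained=True)
    tf65 = timm.create_model('xception65.tf_in1k', pretrained=True)
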
@register_model