Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
More model feature extraction support, start to deprecate senet.py, dilations added to regnet, add proper aligned xception
This commit is contained in:
parent
7729f40dca
commit
a66df5fb91
timm/models/__init__.py
@@ -1,30 +1,31 @@
-from .inception_v4 import *
-from .inception_resnet_v2 import *
 from .densenet import *
-from .resnet import *
 from .dla import *
 from .dpn import *
-from .senet import *
-from .xception import *
-from .nasnet import *
-from .pnasnet import *
-from .selecsls import *
 from .efficientnet import *
-from .mobilenetv3 import *
-from .inception_v3 import *
 from .gluon_resnet import *
 from .gluon_xception import *
-from .res2net import *
-from .dla import *
 from .hrnet import *
+from .inception_resnet_v2 import *
+from .inception_v3 import *
+from .inception_v4 import *
+from .mobilenetv3 import *
+from .nasnet import *
+from .pnasnet import *
+from .regnet import *
+from .res2net import *
+from .resnest import *
+from .resnet import *
+from .selecsls import *
+from .senet import *
 from .sknet import *
 from .tresnet import *
-from .resnest import *
-from .regnet import *
 from .vovnet import *
+from .xception import *
+from .xception_aligned import *
 
-from .registry import *
 from .factory import create_model
 from .helpers import load_checkpoint, resume_checkpoint
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
+from .registry import *
timm/models/helpers.py
@@ -74,6 +74,9 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
 
     state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
 
+    if filter_fn is not None:
+        state_dict = filter_fn(state_dict)
+
     if in_chans == 1:
         conv1_name = cfg['first_conv']
         logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
@@ -95,9 +98,6 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
         del state_dict[classifier_name + '.bias']
         strict = False
 
-    if filter_fn is not None:
-        state_dict = filter_fn(state_dict)
-
     model.load_state_dict(state_dict, strict=strict)
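Moving filter_fn ahead of the input/classifier surgery is the substantive part of this hunk: a filter that remaps checkpoint keys now runs before load_pretrained inspects the first_conv and classifier entries, so that surgery sees the remapped names. A minimal sketch of a typical filter (hypothetical example, not code from this commit):

    def _strip_module_prefix(state_dict):
        # drop the 'module.' prefix that torch.nn.DataParallel checkpoints carry
        return {k.replace('module.', '', 1): v for k, v in state_dict.items()}

    # load_pretrained(model, cfg, filter_fn=_strip_module_prefix)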
timm/models/inception_resnet_v2.py
@@ -223,11 +223,12 @@ class Block8(nn.Module):
 
 
 class InceptionResnetV2(nn.Module):
-    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
+    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
         super(InceptionResnetV2, self).__init__()
         self.drop_rate = drop_rate
         self.num_classes = num_classes
         self.num_features = 1536
+        assert output_stride == 32
 
         self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
         self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
@@ -340,16 +341,16 @@ class InceptionResnetV2(nn.Module):
 
 
 def _inception_resnet_v2(variant, pretrained=False, **kwargs):
-    load_strict, features, out_indices = True, False, None
+    features, out_indices = False, None
     if kwargs.pop('features_only', False):
-        load_strict, features, out_indices = False, True, kwargs.pop('out_indices', (0, 1, 2, 3, 4))
         kwargs.pop('num_classes', 0)
+        features = True
+        out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
     model = InceptionResnetV2(**kwargs)
     model.default_cfg = default_cfgs[variant]
     if pretrained:
         load_pretrained(
             model,
-            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
+            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
     if features:
         model = FeatureNet(model, out_indices)
     return model
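With the factory reworked, strict loading simply becomes the complement of feature mode, and the classifier-free model is wrapped in FeatureNet. A hedged usage sketch (assumes create_model forwards features_only and out_indices into this entrypoint unchanged):

    import torch
    from timm import create_model

    model = create_model('inception_resnet_v2', features_only=True, out_indices=(0, 1, 2, 3, 4))
    feats = model(torch.randn(1, 3, 299, 299))
    print([tuple(f.shape) for f in feats])  # one map per requested index, shallow to deep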
timm/models/nasnet.py
@@ -400,14 +400,15 @@ class ReductionCell1(nn.Module):
 class NASNetALarge(nn.Module):
     """NASNetALarge (6 @ 4032) """
 
-    def __init__(self, num_classes=1000, in_chans=1, stem_size=96, num_features=4032, channel_multiplier=2,
-                 drop_rate=0., global_pool='avg', pad_type='same'):
+    def __init__(self, num_classes=1000, in_chans=1, stem_size=96, channel_multiplier=2,
+                 num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'):
         super(NASNetALarge, self).__init__()
         self.num_classes = num_classes
         self.stem_size = stem_size
         self.num_features = num_features
         self.channel_multiplier = channel_multiplier
         self.drop_rate = drop_rate
+        assert output_stride == 32
 
         channels = self.num_features // 24
         # 24 is default value for the architecture
timm/models/pnasnet.py
@@ -236,11 +236,12 @@ class Cell(CellBase):
 
 
 class PNASNet5Large(nn.Module):
-    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg', padding=''):
+    def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0.5, global_pool='avg', padding=''):
         super(PNASNet5Large, self).__init__()
         self.num_classes = num_classes
-        self.num_features = 4320
         self.drop_rate = drop_rate
+        self.num_features = 4320
+        assert output_stride == 32
 
         self.conv_0 = ConvBnAct(
             in_chans, 96, kernel_size=3, stride=2, padding=0,
timm/models/regnet.py
@@ -12,15 +12,15 @@ Weights from original impl have been modified
 * remap names to match the ones here
 
 """
 import torch
+import numpy as np
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np
 
-from .registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .features import FeatureNet
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d, AvgPool2dSame, ConvBnAct, SEModule
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .registry import register_model
 
 
 def _mcfg(**kwargs):
@@ -128,18 +128,17 @@ class Bottleneck(nn.Module):
     after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
     """
 
-    def __init__(self, in_chs, out_chs, stride=1, bottleneck_ratio=1, group_width=1, se_ratio=0.25,
-                 dilation=1, first_dilation=None, downsample=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-                 aa_layer=None, drop_block=None, drop_path=None):
+    def __init__(self, in_chs, out_chs, stride=1, dilation=1, bottleneck_ratio=1, group_width=1, se_ratio=0.25,
+                 downsample=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None,
+                 drop_block=None, drop_path=None):
         super(Bottleneck, self).__init__()
         bottleneck_chs = int(round(out_chs * bottleneck_ratio))
         groups = bottleneck_chs // group_width
-        first_dilation = first_dilation or dilation
 
         cargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block)
         self.conv1 = ConvBnAct(in_chs, bottleneck_chs, kernel_size=1, **cargs)
         self.conv2 = ConvBnAct(
-            bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=first_dilation,
+            bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation,
             groups=groups, **cargs)
         if se_ratio:
             se_channels = int(round(in_chs * se_ratio))
@@ -172,16 +171,16 @@ class Bottleneck(nn.Module):
 
 
 def downsample_conv(
-        in_chs, out_chs, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
+        in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None):
     norm_layer = norm_layer or nn.BatchNorm2d
     kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
-    first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
+    dilation = dilation if kernel_size > 1 else 1
     return ConvBnAct(
-        in_chs, out_chs, kernel_size, stride=stride, dilation=first_dilation, norm_layer=norm_layer, act_layer=None)
+        in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, act_layer=None)
 
 
 def downsample_avg(
-        in_chs, out_chs, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
+        in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None):
     """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
     norm_layer = norm_layer or nn.BatchNorm2d
     avg_stride = stride if dilation == 1 else 1
@@ -196,21 +195,24 @@ def downsample_avg(
 class RegStage(nn.Module):
     """Stage (sequence of blocks w/ the same output shape)."""
 
-    def __init__(self, in_chs, out_chs, stride, depth, block_fn, bottle_ratio, group_width, se_ratio):
+    def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width,
+                 block_fn=Bottleneck, se_ratio=0.):
         super(RegStage, self).__init__()
         block_kwargs = {}  # FIXME setup to pass various aa, norm, act layer common args
+        first_dilation = 1 if dilation in (1, 2) else 2
         for i in range(depth):
            block_stride = stride if i == 0 else 1
            block_in_chs = in_chs if i == 0 else out_chs
+            block_dilation = first_dilation if i == 0 else dilation
            if (block_in_chs != out_chs) or (block_stride != 1):
-                proj_block = downsample_conv(block_in_chs, out_chs, 1, stride)
+                proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation)
            else:
                proj_block = None
 
            name = "b{}".format(i + 1)
            self.add_module(
                name, block_fn(
-                    block_in_chs, out_chs, block_stride, bottle_ratio, group_width, se_ratio,
+                    block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio,
                    downsample=proj_block, **block_kwargs)
            )
@@ -247,26 +249,30 @@ class RegNet(nn.Module):
     Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
     """
 
-    def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.,
+    def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
                  zero_init_last_bn=True):
         super().__init__()
         # TODO add drop block, drop path, anti-aliasing, custom bn/act args
         self.num_classes = num_classes
         self.drop_rate = drop_rate
+        assert output_stride in (8, 16, 32)
 
         # Construct the stem
         stem_width = cfg['stem_width']
         self.stem = ConvBnAct(in_chans, stem_width, 3, stride=2)
+        self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')]
 
         # Construct the stages
-        block_fn = Bottleneck
         prev_width = stem_width
-        stage_params = self._get_stage_params(cfg)
+        curr_stride = 2
+        stage_params = self._get_stage_params(cfg, output_stride=output_stride)
         se_ratio = cfg['se_ratio']
-        for i, (d, w, s, br, gw) in enumerate(stage_params):
-            self.add_module(
-                "s{}".format(i + 1), RegStage(prev_width, w, s, d, block_fn, br, gw, se_ratio))
-            prev_width = w
+        for i, stage_args in enumerate(stage_params):
+            stage_name = "s{}".format(i + 1)
+            self.add_module(stage_name, RegStage(prev_width, **stage_args, se_ratio=se_ratio))
+            prev_width = stage_args['out_chs']
+            curr_stride *= stage_args['stride']
+            self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)]
 
         # Construct the head
         self.num_features = prev_width
@@ -287,7 +293,7 @@ class RegNet(nn.Module):
             if hasattr(m, 'zero_init_last_bn'):
                 m.zero_init_last_bn()
 
-    def _get_stage_params(self, cfg, stride=2):
+    def _get_stage_params(self, cfg, default_stride=2, output_stride=32):
         # Generate RegNet ws per block
         w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth']
         widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
@@ -298,12 +304,26 @@ class RegNet(nn.Module):
         # Use the same group width, bottleneck mult and stride for each stage
         stage_groups = [cfg['group_w'] for _ in range(num_stages)]
         stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)]
-        stage_strides = [stride for _ in range(num_stages)]
-        # FIXME add dilation / output_stride support
+        stage_strides = []
+        stage_dilations = []
+        total_stride = 2
+        dilation = 1
+        for _ in range(num_stages):
+            if total_stride >= output_stride:
+                dilation *= default_stride
+                stride = 1
+            else:
+                stride = default_stride
+                total_stride *= stride
+            stage_strides.append(stride)
+            stage_dilations.append(dilation)
 
         # Adjust the compatibility of ws and gws
         stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
-        stage_params = list(zip(stage_depths, stage_widths, stage_strides, stage_bottle_ratios, stage_groups))
+        param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width']
+        stage_params = [
+            dict(zip(param_names, params)) for params in
+            zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups)]
         return stage_params
 
     def get_classifier(self):
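The loop is the core of the new dilation support: stages consume stride until the running product hits output_stride, after which each would-be stride is converted into a dilation multiplier instead. A standalone rerun of the arithmetic for a 4-stage net at output_stride=8 (the stem already contributes stride 2):

    strides, dilations = [], []
    total_stride, dilation = 2, 1
    for _ in range(4):
        if total_stride >= 8:  # output_stride reached, dilate instead of striding
            dilation *= 2
            stride = 1
        else:
            stride = 2
            total_stride *= stride
        strides.append(stride)
        dilations.append(dilation)
    print(strides, dilations)  # [2, 2, 1, 1] [1, 1, 2, 4]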
@@ -324,20 +344,20 @@ class RegNet(nn.Module):
 
 
 def _regnet(variant, pretrained, **kwargs):
-    load_strict = True
-    model_class = RegNet
+    features = False
+    out_indices = None
     if kwargs.pop('features_only', False):
-        assert False, 'Not Implemented'  # TODO
-        load_strict = False
         kwargs.pop('num_classes', 0)
+        features = True
+        out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
     model_cfg = model_cfgs[variant]
-    default_cfg = default_cfgs[variant]
-    model = model_class(model_cfg, **kwargs)
-    model.default_cfg = default_cfg
+    model = RegNet(model_cfg, **kwargs)
+    model.default_cfg = default_cfgs[variant]
     if pretrained:
         load_pretrained(
-            model, default_cfg,
-            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
+            model,
+            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
+    if features:
+        model = FeatureNet(model, out_indices=out_indices)
     return model
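features_only is no longer a dead end here (the 'Not Implemented' assert is gone); the entrypoint builds the full model, loads weights non-strictly, then wraps it. A hedged sketch, using one of the variants registered in this file and assuming FeatureNet yields one tensor per out_index, matching feature_info:

    import torch

    model = _regnet('regnetx_002', pretrained=False, features_only=True)
    feats = model(torch.randn(1, 3, 224, 224))
    print([tuple(f.shape) for f in feats])  # reductions 2, 4, 8, 16, 32 per feature_info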
timm/models/resnet.py
@@ -33,6 +33,7 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = {
+    # ResNet and Wide ResNet
     'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
     'resnet34': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'),
@@ -54,6 +55,8 @@ default_cfgs = {
     'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
     'wide_resnet50_2': _cfg(url='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth'),
     'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'),
+
+    # ResNeXt
     'resnext50_32x4d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d_ra-d733960d.pth',
         interpolation='bicubic'),
@@ -64,10 +67,17 @@ default_cfgs = {
     'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
     'resnext101_64x4d': _cfg(url=''),
     'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'),
+
+    # ResNeXt models - Weakly Supervised Pretraining on Instagram Hashtags
+    # from https://github.com/facebookresearch/WSL-Images
+    # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
     'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'),
     'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'),
     'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'),
     'ig_resnext101_32x48d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth'),
+
+    # Semi-Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models
+    # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
     'ssl_resnet18': _cfg(
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth'),
     'ssl_resnet50': _cfg(
@@ -80,6 +90,9 @@ default_cfgs = {
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth'),
     'ssl_resnext101_32x16d': _cfg(
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth'),
+
+    # Semi-Weakly Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models
+    # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
     'swsl_resnet18': _cfg(
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth'),
     'swsl_resnet50': _cfg(
@@ -92,6 +105,31 @@ default_cfgs = {
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth'),
     'swsl_resnext101_32x16d': _cfg(
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'),
+
+    # Squeeze-Excitation ResNets, to eventually replace the models in senet.py
+    'seresnet18': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'seresnet34': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'seresnet50': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'seresnet50tn': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'seresnet101': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'seresnet152': _cfg(
+        url='',
+        interpolation='bicubic'),
+
+    # Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
+    'seresnext26_32x4d': _cfg(
+        url='',
+        interpolation='bicubic'),
     'seresnext26d_32x4d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth',
         interpolation='bicubic'),
@@ -101,9 +139,19 @@ default_cfgs = {
     'seresnext26tn_32x4d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
         interpolation='bicubic'),
-    'ecaresnext26tn_32x4d': _cfg(
-        url='',
-        interpolation='bicubic'),
+    'seresnext50_32x4d': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'seresnext101_32x4d': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'seresnext101_32x8d': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'senet154': _cfg(
+        url='',
+        interpolation='bicubic'),
 
     # Efficient Channel Attention ResNets
     'ecaresnet18': _cfg(),
     'ecaresnet50': _cfg(),
     'ecaresnetlight': _cfg(
@@ -121,6 +169,16 @@ default_cfgs = {
     'ecaresnet101d_pruned': _cfg(
         url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
         interpolation='bicubic'),
+
+    # Efficient Channel Attention ResNeXts
+    'ecaresnext26tn_32x4d': _cfg(
+        url='',
+        interpolation='bicubic'),
+    'ecaresnext50_32x4d': _cfg(
+        url='',
+        interpolation='bicubic'),
+
+    # ResNets with anti-aliasing blur pool
     'resnetblur18': _cfg(
         interpolation='bicubic'),
     'resnetblur50': _cfg(
@@ -278,6 +336,14 @@ class Bottleneck(nn.Module):
         return x
 
 
+def setup_drop_block(drop_block_rate=0.):
+    return [
+        None,
+        None,
+        DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
+        DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
+
+
 def downsample_conv(
         in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
     norm_layer = norm_layer or nn.BatchNorm2d
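setup_drop_block centralizes the DropBlock schedule previously built inline in ResNet.__init__ (see the Feature Blocks hunk below): nothing on layers 1-2, a 5x5 block on layer 3, a 3x3 block on layer 4. A quick check of the returned schedule:

    db = setup_drop_block(0.1)
    print([type(d).__name__ if d else None for d in db])
    # [None, None, 'DropBlock2d', 'DropBlock2d']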
@@ -386,6 +452,7 @@ class ResNet(nn.Module):
                  act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
                  drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
         block_args = block_args or dict()
+        assert output_stride in (8, 16, 32)
         self.num_classes = num_classes
         deep_stem = 'deep' in stem_type
         self.inplanes = stem_width * 2 if deep_stem else 64
@@ -393,7 +460,6 @@ class ResNet(nn.Module):
         self.base_width = base_width
         self.drop_rate = drop_rate
         self.expansion = block.expansion
-        self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')]
         super(ResNet, self).__init__()
 
         # Stem
@@ -414,6 +480,8 @@ class ResNet(nn.Module):
             self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
         self.bn1 = norm_layer(self.inplanes)
         self.act1 = act_layer(inplace=True)
+        self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')]
 
         # Stem Pooling
         if aa_layer is not None:
             self.maxpool = nn.Sequential(*[
@@ -424,32 +492,26 @@ class ResNet(nn.Module):
             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # Feature Blocks
-        channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
-        if output_stride == 16:
-            strides[3] = 1
-            dilations[3] = 2
-        elif output_stride == 8:
-            strides[2:4] = [1, 1]
-            dilations[2:4] = [2, 4]
-        else:
-            assert output_stride == 32
+        channels = [64, 128, 256, 512]
         dp = DropPath(drop_path_rate) if drop_path_rate else None
-        db = [
-            None, None,
-            DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
-            DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
-        layer_args = list(zip(channels, layers, strides, dilations))
+        db = setup_drop_block(drop_block_rate)
         layer_kwargs = dict(
             reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
             avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
-        current_stride = 4
+        total_stride = 4
+        dilation = 1
         for i in range(4):
             layer_name = f'layer{i + 1}'
+            stride = 2 if i > 0 else 1
+            if total_stride >= output_stride:
+                dilation *= stride
+                stride = 1
+            else:
+                total_stride *= stride
             self.add_module(layer_name, self._make_layer(
-                block, *layer_args[i], drop_block=db[i], **layer_kwargs))
-            current_stride *= strides[i]
+                block, channels[i], layers[i], stride, dilation, drop_block=db[i], **layer_kwargs))
             self.feature_info.append(dict(
-                num_chs=self.inplanes, reduction=current_stride, module=layer_name))
+                num_chs=self.inplanes, reduction=total_stride, module=layer_name))
 
         # Head (Pooling and Classifier)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
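This swaps the hard-coded stride/dilation tables for output_stride 16 and 8 for the same running-stride bookkeeping adopted in regnet.py above, and keys feature_info reductions off it. A hedged sanity check (assumes the resnet50 entrypoint forwards output_stride, per the signature earlier in this file):

    import torch
    from timm import create_model

    model = create_model('resnet50', output_stride=8)
    feat = model.forward_features(torch.randn(1, 3, 224, 224))
    print(feat.shape)  # expected (1, 2048, 28, 28); layers 3-4 dilate instead of striding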
@@ -872,55 +934,6 @@ def swsl_resnext101_32x16d(pretrained=True, **kwargs):
     return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args)
 
 
-@register_model
-def seresnext26d_32x4d(pretrained=False, **kwargs):
-    """Constructs a SE-ResNeXt-26-D model.
-    This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
-    combination of deep stem and avg_pool in downsample.
-    """
-    model_args = dict(
-        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
-        stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
-    return _create_resnet('seresnext26d_32x4d', pretrained, **model_args)
-
-
-@register_model
-def seresnext26t_32x4d(pretrained=False, **kwargs):
-    """Constructs a SE-ResNet-26-T model.
-    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
-    in the deep stem.
-    """
-    model_args = dict(
-        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
-        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
-    return _create_resnet('seresnext26t_32x4d', pretrained, **model_args)
-
-
-@register_model
-def seresnext26tn_32x4d(pretrained=False, **kwargs):
-    """Constructs a SE-ResNeXt-26-TN model.
-    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
-    in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
-    """
-    model_args = dict(
-        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
-        stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
-    return _create_resnet('seresnext26tn_32x4d', pretrained, **model_args)
-
-
-@register_model
-def ecaresnext26tn_32x4d(pretrained=False, **kwargs):
-    """Constructs an ECA-ResNeXt-26-TN model.
-    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
-    in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
-    this model replaces SE module with the ECA module
-    """
-    model_args = dict(
-        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
-        stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
-    return _create_resnet('ecaresnext26tn_32x4d', pretrained, **model_args)
-
-
 @register_model
 def ecaresnet18(pretrained=False, **kwargs):
     """ Constructs an ECA-ResNet-18 model.
@@ -989,6 +1002,19 @@ def ecaresnet101d_pruned(pretrained=False, **kwargs):
     return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)
 
 
+@register_model
+def ecaresnext26tn_32x4d(pretrained=False, **kwargs):
+    """Constructs an ECA-ResNeXt-26-TN model.
+    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
+    in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
+    This model replaces the SE module with the ECA module.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
+    return _create_resnet('ecaresnext26tn_32x4d', pretrained, **model_args)
+
+
 @register_model
 def resnetblur18(pretrained=False, **kwargs):
     """Constructs a ResNet-18 model with blur anti-aliasing
@@ -1003,3 +1029,123 @@ def resnetblur50(pretrained=False, **kwargs):
     """
     model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs)
     return _create_resnet('resnetblur50', pretrained, **model_args)
+
+
+@register_model
+def seresnet18(pretrained=False, **kwargs):
+    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnet18', pretrained, **model_args)
+
+
+@register_model
+def seresnet34(pretrained=False, **kwargs):
+    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnet34', pretrained, **model_args)
+
+
+@register_model
+def seresnet50(pretrained=False, **kwargs):
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnet50', pretrained, **model_args)
+
+
+@register_model
+def seresnet50tn(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
+        block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnet50tn', pretrained, **model_args)
+
+
+@register_model
+def seresnet101(pretrained=False, **kwargs):
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnet101', pretrained, **model_args)
+
+
+@register_model
+def seresnet152(pretrained=False, **kwargs):
+    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnet152', pretrained, **model_args)
+
+
+@register_model
+def seresnext26_32x4d(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4,
+        block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext26_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext26d_32x4d(pretrained=False, **kwargs):
+    """Constructs a SE-ResNeXt-26-D model.
+    This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
+    combination of deep stem and avg_pool in downsample.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext26d_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext26t_32x4d(pretrained=False, **kwargs):
+    """Constructs a SE-ResNet-26-T model.
+    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
+    in the deep stem.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext26t_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext26tn_32x4d(pretrained=False, **kwargs):
+    """Constructs a SE-ResNeXt-26-TN model.
+    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
+    in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext26tn_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext50_32x4d(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
+        block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext50_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext101_32x4d(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
+        block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext101_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext101_32x8d(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
+        block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('seresnext101_32x8d', pretrained, **model_args)
+
+
+@register_model
+def senet154(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
+        down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs)
+    return _create_resnet('senet154', pretrained, **model_args)
+
+
+@register_model
+def eseresnet50(pretrained=False, **kwargs):
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='ese'), **kwargs)
+    return _create_resnet('seresnet50', pretrained, **model_args)
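These entrypoints rebuild the SE family on the generic ResNet class, with attention arriving through block_args rather than the bespoke SENet code being deprecated below. A brief sketch of what that buys:

    m = seresnet50(pretrained=False, num_classes=10)
    print(type(m).__name__)  # ResNet: output_stride, feature extraction, etc. now apply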
timm/models/senet.py
@@ -7,6 +7,9 @@ Original model: https://github.com/hujie-frank/SENet
 
 ResNet code gently borrowed from
 https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+
+FIXME I'm deprecating these models and moving them to ResNet as I don't want to maintain duplicate
+support for extras like dilation, switchable BN/activations, feature extraction, etc. that don't exist here.
 """
 import math
 from collections import OrderedDict
@@ -397,7 +400,7 @@ class SENet(nn.Module):
 
 
 @register_model
-def seresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def legacy_seresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['seresnet18']
     model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16,
                   inplanes=64, input_3x3=False,
@@ -410,7 +413,7 @@ def seresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
 
 
 @register_model
-def seresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def legacy_seresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['seresnet34']
     model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16,
                   inplanes=64, input_3x3=False,
@@ -423,7 +426,7 @@ def seresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
 
 
 @register_model
-def seresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def legacy_seresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['seresnet50']
     model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
                   inplanes=64, input_3x3=False,
@@ -436,7 +439,7 @@ def seresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
 
 
 @register_model
-def seresnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def legacy_seresnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['seresnet101']
     model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
                   inplanes=64, input_3x3=False,
@@ -449,7 +452,7 @@ def seresnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
 
 
 @register_model
-def seresnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def legacy_seresnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['seresnet152']
     model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
                   inplanes=64, input_3x3=False,
@@ -462,7 +465,7 @@ def seresnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
 
 
 @register_model
-def senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def legacy_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['senet154']
     model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
                   num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -473,7 +476,7 @@ def senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
 
 
 @register_model
-def seresnext26_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def legacy_seresnext26_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['seresnext26_32x4d']
     model = SENet(SEResNeXtBottleneck, [2, 2, 2, 2], groups=32, reduction=16,
                   inplanes=64, input_3x3=False,
@@ -486,7 +489,7 @@ def seresnext26_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
 
 
 @register_model
-def seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def legacy_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['seresnext50_32x4d']
     model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
                   inplanes=64, input_3x3=False,
@@ -499,7 +502,7 @@ def seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
 
 
 @register_model
-def seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def legacy_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['seresnext101_32x4d']
     model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
                   inplanes=64, input_3x3=False,
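Only registry names change in senet.py; architectures and checkpoints are untouched, so existing users of the old implementation switch by prefixing legacy_. For example (assuming the registry resolves the renamed entrypoints as usual):

    from timm import create_model

    legacy = create_model('legacy_seresnet50', pretrained=False)  # original SENet class
    modern = create_model('seresnet50', pretrained=False)         # ResNet-based replacement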
timm/models/vovnet.py
@@ -275,13 +275,14 @@ class ClassifierHead(nn.Module):
 class VovNet(nn.Module):
 
     def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
-                 norm_layer=BatchNormAct2d):
+                 output_stride=32, norm_layer=BatchNormAct2d):
         """ VovNet (v2)
         """
         super(VovNet, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         assert stem_stride in (4, 2)
+        assert output_stride == 32  # FIXME support dilation
 
         stem_chs = cfg["stem_chs"]
         stage_conv_chs = cfg["stage_conv_chs"]
@@ -349,7 +350,6 @@ def _vovnet(variant, pretrained=False, **kwargs):
     out_indices = None
     if kwargs.pop('features_only', False):
         features = True
-        kwargs.pop('num_classes', 0)
         out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
     model_cfg = model_cfgs[variant]
     model = VovNet(model_cfg, **kwargs)
@@ -412,10 +412,11 @@ def eca_vovnet39b(pretrained=False, **kwargs):
 
 @register_model
 def ese_vovnet39b_evos(pretrained=False, **kwargs):
-    def norm_act_fn(num_features, **kwargs):
-        return create_norm_act('EvoNormSample', num_features, jit=False, **kwargs)
+    def norm_act_fn(num_features, **nkwargs):
+        return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs)
     return _vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)
 
 
 @register_model
 def ese_vovnet99b_iabn(pretrained=False, **kwargs):
     norm_layer = get_norm_act_layer('iabn')
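The nkwargs rename is a shadowing fix: the closure's **kwargs hid the entrypoint's kwargs of the same name inside norm_act_fn's body. A generic illustration of the hazard (not timm code):

    def entrypoint(**kwargs):
        def norm_act_fn(num_features, **kwargs):  # shadows the outer kwargs here
            return num_features, kwargs           # outer kwargs unreachable in this scope
        return norm_act_fn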
timm/models/xception_aligned.py (new file, 278 lines)
@@ -0,0 +1,278 @@
+"""Pytorch impl of Aligned Xception
+
+This is a correct impl of Aligned Xception (Deeplab) models compatible with TF definition.
+
+Hacked together by Ross Wightman
+"""
+from collections import OrderedDict
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .features import FeatureNet
+from .helpers import load_pretrained
+from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d
+from .registry import register_model
+
+__all__ = ['XceptionAligned']
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10),
+        'crop_pct': 0.903, 'interpolation': 'bicubic',
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'stem.0', 'classifier': 'head.fc',
+        **kwargs
+    }
+
+
+default_cfgs = dict(
+    xception41=_cfg(url=''),
+    xception65=_cfg(url=''),
+    xception71=_cfg(url=''),
+)
+
+
+class SeparableConv2d(nn.Module):
+    def __init__(
+            self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='',
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
+        super(SeparableConv2d, self).__init__()
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+
+        # depthwise convolution
+        self.conv_dw = create_conv2d(
+            inplanes, inplanes, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, depthwise=True)
+        self.bn_dw = norm_layer(inplanes, **norm_kwargs)
+        if act_layer is not None:
+            self.act_dw = act_layer(inplace=True)
+        else:
+            self.act_dw = None
+
+        # pointwise convolution
+        self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1)
+        self.bn_pw = norm_layer(planes, **norm_kwargs)
+        if act_layer is not None:
+            self.act_pw = act_layer(inplace=True)
+        else:
+            self.act_pw = None
+
+    def forward(self, x):
+        x = self.conv_dw(x)
+        x = self.bn_dw(x)
+        if self.act_dw is not None:
+            x = self.act_dw(x)
+        x = self.conv_pw(x)
+        x = self.bn_pw(x)
+        if self.act_pw is not None:
+            x = self.act_pw(x)
+        return x
+
+
+class XceptionModule(nn.Module):
+    def __init__(
+            self, in_chs, out_chs, stride=1, dilation=1, pad_type='',
+            start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None, norm_kwargs=None):
+        super(XceptionModule, self).__init__()
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+        if isinstance(out_chs, (list, tuple)):
+            assert len(out_chs) == 3
+        else:
+            out_chs = (out_chs,) * 3
+        self.in_channels = in_chs
+        self.out_channels = out_chs[-1]
+        self.no_skip = no_skip
+        if not no_skip and (self.out_channels != self.in_channels or stride != 1):
+            self.shortcut = ConvBnAct(
+                in_chs, self.out_channels, 1, stride=stride,
+                norm_layer=norm_layer, norm_kwargs=norm_kwargs, act_layer=None)
+        else:
+            self.shortcut = None
+
+        separable_act_layer = None if start_with_relu else act_layer
+        self.stack = nn.Sequential()
+        for i in range(3):
+            if start_with_relu:
+                self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0))
+            self.stack.add_module(f'conv{i + 1}', SeparableConv2d(
+                in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type,
+                act_layer=separable_act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs))
+            in_chs = out_chs[i]
+
+    def forward(self, x):
+        skip = x
+        x = self.stack(x)
+        if self.shortcut is not None:
+            skip = self.shortcut(skip)
+        if not self.no_skip:
+            x = x + skip
+        return x
+
+
+class ClassifierHead(nn.Module):
+    """Head."""
+
+    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
+        super(ClassifierHead, self).__init__()
+        self.drop_rate = drop_rate
+        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
+        if num_classes > 0:
+            self.fc = nn.Linear(in_chs, num_classes, bias=True)
+        else:
+            self.fc = nn.Identity()
+
+    def forward(self, x):
+        x = self.global_pool(x).flatten(1)
+        if self.drop_rate:
+            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
+        x = self.fc(x)
+        return x
+
+
+class XceptionAligned(nn.Module):
+    """Modified Aligned Xception
+    """
+
+    def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32,
+                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_rate=0., global_pool='avg'):
+        super(XceptionAligned, self).__init__()
+        self.num_classes = num_classes
+        self.drop_rate = drop_rate
+        assert output_stride in (8, 16, 32)
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+
+        xtra_args = dict(act_layer=act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+        self.stem = nn.Sequential(*[
+            ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **xtra_args),
+            ConvBnAct(32, 64, kernel_size=3, stride=1, **xtra_args)
+        ])
+        curr_dilation = 1
+        curr_stride = 2
+        self.feature_info = [dict(num_chs=64, reduction=curr_stride, module='stem.1')]
+
+        self.blocks = nn.Sequential()
+        for i, b in enumerate(block_cfg):
+            feature_extract = False
+            b['dilation'] = curr_dilation
+            if b['stride'] > 1:
+                feature_extract = True
+                next_stride = curr_stride * b['stride']
+                if next_stride > output_stride:
+                    curr_dilation *= b['stride']
+                    b['stride'] = 1
+                else:
+                    curr_stride = next_stride
+            self.blocks.add_module(str(i), XceptionModule(**b, **xtra_args))
+            self.num_features = self.blocks[-1].out_channels
+            if feature_extract:
+                self.feature_info += [dict(
+                    num_chs=self.num_features, reduction=curr_stride, module=f'blocks.{i}.stack.act2')]
+
+        self.feature_info += [dict(
+            num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
+
+        self.head = ClassifierHead(
+            in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+    def forward_features(self, x):
+        x = self.stem(x)
+        x = self.blocks(x)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+
+def _xception(variant, pretrained=False, **kwargs):
+    features = False
+    out_indices = None
+    if kwargs.pop('features_only', False):
+        features = True
+        kwargs.pop('num_classes', 0)
+        out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
+    model = XceptionAligned(**kwargs)
+    model.default_cfg = default_cfgs[variant]
+    if pretrained:
+        load_pretrained(
+            model,
+            num_classes=kwargs.get('num_classes', 0),
+            in_chans=kwargs.get('in_chans', 3),
+            strict=not features)
+    if features:
+        model = FeatureNet(model, out_indices)
+    return model
+
+
+@register_model
+def xception41(pretrained=False, **kwargs):
+    """ Modified Aligned Xception-41
+    """
+    block_cfg = [
+        # entry flow
+        dict(in_chs=64, out_chs=128, stride=2),
+        dict(in_chs=128, out_chs=256, stride=2),
+        dict(in_chs=256, out_chs=728, stride=2),
+        # middle flow
+        *([dict(in_chs=728, out_chs=728, stride=1)] * 8),
+        # exit flow
+        dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
+        dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
+    ]
+    model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs)
+    return _xception('xception41', pretrained=pretrained, **model_args)
+
+
+@register_model
+def xception65(pretrained=False, **kwargs):
+    """ Modified Aligned Xception-65
+    """
+    block_cfg = [
+        # entry flow
+        dict(in_chs=64, out_chs=128, stride=2),
+        dict(in_chs=128, out_chs=256, stride=2),
+        dict(in_chs=256, out_chs=728, stride=2),
+        # middle flow
+        *([dict(in_chs=728, out_chs=728, stride=1)] * 16),
+        # exit flow
+        dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
+        dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
+    ]
+    model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs)
+    return _xception('xception65', pretrained=pretrained, **model_args)
+
+
+@register_model
+def xception71(pretrained=False, **kwargs):
+    """ Modified Aligned Xception-71
+    """
+    block_cfg = [
+        # entry flow
+        dict(in_chs=64, out_chs=128, stride=2),
+        dict(in_chs=128, out_chs=256, stride=1),
+        dict(in_chs=256, out_chs=256, stride=2),
+        dict(in_chs=256, out_chs=728, stride=1),
+        dict(in_chs=728, out_chs=728, stride=2),
+        # middle flow
+        *([dict(in_chs=728, out_chs=728, stride=1)] * 16),
+        # exit flow
+        dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
+        dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
+    ]
+    model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs)
+    return _xception('xception71', pretrained=pretrained, **model_args)