Merge pull request #475 from rwightman/pre-release/0.4.5
Prep for PyPi release. Tweak NFNet, ResNetV2, RexNet feature extraction, use pre-act features. Also, add act args to RexNet #202pull/489/head v0.4.5
commit
5b28ef4100
|
@ -17,8 +17,8 @@ jobs:
|
|||
matrix:
|
||||
os: [ubuntu-latest, macOS-latest]
|
||||
python: ['3.8']
|
||||
torch: ['1.7.0']
|
||||
torchvision: ['0.8.1']
|
||||
torch: ['1.8.0']
|
||||
torchvision: ['0.9.0']
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
|
|
|
@ -2,6 +2,10 @@
|
|||
|
||||
## What's New
|
||||
|
||||
### March 7, 2021
|
||||
* First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc).
|
||||
* Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation.
|
||||
|
||||
### Feb 18, 2021
|
||||
* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
|
||||
* Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
|
||||
|
|
|
@ -21,7 +21,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
|
|||
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
|
||||
EXCLUDE_FILTERS = [
|
||||
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm',
|
||||
'*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*'] + NON_STD_FILTERS
|
||||
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*'] + NON_STD_FILTERS
|
||||
else:
|
||||
EXCLUDE_FILTERS = NON_STD_FILTERS
|
||||
|
||||
|
|
|
@ -264,9 +264,11 @@ class CrossStage(nn.Module):
|
|||
if self.conv_down is not None:
|
||||
x = self.conv_down(x)
|
||||
x = self.conv_exp(x)
|
||||
xs, xb = x.chunk(2, dim=1)
|
||||
split = x.shape[1] // 2
|
||||
xs, xb = x[:, :split], x[:, split:]
|
||||
xb = self.blocks(xb)
|
||||
out = self.conv_transition(torch.cat([xs, self.conv_transition_b(xb)], dim=1))
|
||||
xb = self.conv_transition_b(xb).contiguous()
|
||||
out = self.conv_transition(torch.cat([xs, xb], dim=1))
|
||||
return out
|
||||
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ except ImportError:
|
|||
def inplace_abn(x, weight, bias, running_mean, running_var,
|
||||
training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
|
||||
raise ImportError(
|
||||
"Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
|
||||
"Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")
|
||||
|
||||
def inplace_abn_sync(**kwargs):
|
||||
inplace_abn(**kwargs)
|
||||
|
|
|
@ -101,11 +101,12 @@ default_cfgs = dict(
|
|||
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
|
||||
|
||||
nfnet_l0a=_dcfg(
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
||||
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)),
|
||||
nfnet_l0b=_dcfg(
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
||||
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)),
|
||||
nfnet_l0c=_dcfg(
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0c-ad1045c2.pth',
|
||||
pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0),
|
||||
|
||||
nf_regnet_b0=_dcfg(
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
|
||||
|
@ -376,9 +377,9 @@ class NormFreeBlock(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
|
||||
def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None, preact_feature=True):
|
||||
stem_stride = 2
|
||||
stem_feature = dict(num_chs=out_chs, reduction=2, module='')
|
||||
stem_feature = dict(num_chs=out_chs, reduction=2, module='stem.conv')
|
||||
stem = OrderedDict()
|
||||
assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
|
||||
if 'deep' in stem_type:
|
||||
|
@ -388,14 +389,14 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
|
|||
stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs)
|
||||
strides = (2, 1, 1, 2)
|
||||
stem_stride = 4
|
||||
stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act4')
|
||||
stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv3')
|
||||
else:
|
||||
if 'tiered' in stem_type:
|
||||
stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # 'T' resnets in resnet.py
|
||||
else:
|
||||
stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets
|
||||
strides = (2, 1, 1)
|
||||
stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3')
|
||||
stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv2')
|
||||
last_idx = len(stem_chs) - 1
|
||||
for i, (c, s) in enumerate(zip(stem_chs, strides)):
|
||||
stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
|
||||
|
@ -477,7 +478,7 @@ class NormFreeNet(nn.Module):
|
|||
self.stem, stem_stride, stem_feat = create_stem(
|
||||
in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
|
||||
|
||||
self.feature_info = [stem_feat] if stem_stride == 4 else []
|
||||
self.feature_info = [stem_feat]
|
||||
drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
|
||||
prev_chs = stem_chs
|
||||
net_stride = stem_stride
|
||||
|
@ -486,8 +487,6 @@ class NormFreeNet(nn.Module):
|
|||
stages = []
|
||||
for stage_idx, stage_depth in enumerate(cfg.depths):
|
||||
stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
|
||||
if stride == 2:
|
||||
self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1')]
|
||||
if net_stride >= output_stride and stride > 1:
|
||||
dilation *= stride
|
||||
stride = 1
|
||||
|
@ -522,6 +521,7 @@ class NormFreeNet(nn.Module):
|
|||
expected_var += cfg.alpha ** 2 # Even if reset occurs, increment expected variance
|
||||
first_dilation = dilation
|
||||
prev_chs = out_chs
|
||||
self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')]
|
||||
stages += [nn.Sequential(*blocks)]
|
||||
self.stages = nn.Sequential(*stages)
|
||||
|
||||
|
@ -529,11 +529,11 @@ class NormFreeNet(nn.Module):
|
|||
# The paper NFRegNet models have an EfficientNet-like final head convolution.
|
||||
self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
|
||||
self.final_conv = conv_layer(prev_chs, self.num_features, 1)
|
||||
self.feature_info[-1] = dict(num_chs=self.num_features, reduction=net_stride, module=f'final_conv')
|
||||
else:
|
||||
self.num_features = prev_chs
|
||||
self.final_conv = nn.Identity()
|
||||
self.final_act = act_layer(inplace=cfg.num_features > 0)
|
||||
self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]
|
||||
|
||||
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
||||
|
||||
|
@ -572,10 +572,6 @@ class NormFreeNet(nn.Module):
|
|||
def _create_normfreenet(variant, pretrained=False, **kwargs):
|
||||
model_cfg = model_cfgs[variant]
|
||||
feature_cfg = dict(flatten_sequential=True)
|
||||
feature_cfg['feature_cls'] = 'hook' # pre-act models need hooks to grab feat from act1 in bottleneck blocks
|
||||
if 'pool' in model_cfg.stem_type and 'deep' not in model_cfg.stem_type:
|
||||
feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2 feat for stride 4, 1 layer maxpool stems
|
||||
|
||||
return build_model_with_cfg(
|
||||
NormFreeNet, variant, pretrained,
|
||||
default_cfg=default_cfgs[variant],
|
||||
|
|
|
@ -323,8 +323,8 @@ class ResNetV2(nn.Module):
|
|||
self.feature_info = []
|
||||
stem_chs = make_div(stem_chs * wf)
|
||||
self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
|
||||
# NOTE no, reduction 2 feature if preact
|
||||
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='' if preact else 'stem.norm'))
|
||||
stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv') if preact else 'stem.norm'
|
||||
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
|
||||
|
||||
prev_chs = stem_chs
|
||||
curr_stride = 4
|
||||
|
@ -343,10 +343,7 @@ class ResNetV2(nn.Module):
|
|||
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
|
||||
prev_chs = out_chs
|
||||
curr_stride *= stride
|
||||
feat_name = f'stages.{stage_idx}'
|
||||
if preact:
|
||||
feat_name = f'stages.{stage_idx + 1}.blocks.0.norm1' if (stage_idx + 1) != len(channels) else 'norm'
|
||||
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=feat_name)]
|
||||
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
|
||||
self.stages.add_module(str(stage_idx), stage)
|
||||
|
||||
self.num_features = prev_chs
|
||||
|
@ -414,13 +411,7 @@ class ResNetV2(nn.Module):
|
|||
|
||||
|
||||
def _create_resnetv2(variant, pretrained=False, **kwargs):
|
||||
# FIXME feature map extraction is not setup properly for pre-activation mode right now
|
||||
preact = kwargs.get('preact', True)
|
||||
feature_cfg = dict(flatten_sequential=True)
|
||||
if preact:
|
||||
feature_cfg['feature_cls'] = 'hook'
|
||||
feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2, 0 level feat for preact
|
||||
|
||||
return build_model_with_cfg(
|
||||
ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True,
|
||||
feature_cfg=feature_cfg, **kwargs)
|
||||
|
|
|
@ -71,7 +71,8 @@ class SEWithNorm(nn.Module):
|
|||
|
||||
|
||||
class LinearBottleneck(nn.Module):
|
||||
def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1, drop_path=None):
|
||||
def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1,
|
||||
act_layer='swish', dw_act_layer='relu6', drop_path=None):
|
||||
super(LinearBottleneck, self).__init__()
|
||||
self.use_shortcut = stride == 1 and in_chs <= out_chs
|
||||
self.in_channels = in_chs
|
||||
|
@ -79,14 +80,14 @@ class LinearBottleneck(nn.Module):
|
|||
|
||||
if exp_ratio != 1.:
|
||||
dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div)
|
||||
self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer="swish")
|
||||
self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer=act_layer)
|
||||
else:
|
||||
dw_chs = in_chs
|
||||
self.conv_exp = None
|
||||
|
||||
self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
|
||||
self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. else None
|
||||
self.act_dw = nn.ReLU6()
|
||||
self.act_dw = create_act_layer(dw_act_layer)
|
||||
|
||||
self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)
|
||||
self.drop_path = drop_path
|
||||
|
@ -131,8 +132,7 @@ def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se
|
|||
|
||||
|
||||
def _build_blocks(
|
||||
block_cfg, prev_chs, width_mult, ch_div=1, drop_path_rate=0., feature_location='bottleneck'):
|
||||
feat_exp = feature_location == 'expansion'
|
||||
block_cfg, prev_chs, width_mult, ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_path_rate=0.):
|
||||
feat_chs = [prev_chs]
|
||||
feature_info = []
|
||||
curr_stride = 2
|
||||
|
@ -141,29 +141,25 @@ def _build_blocks(
|
|||
for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg):
|
||||
if stride > 1:
|
||||
fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
|
||||
if block_idx > 0 and feat_exp:
|
||||
fname += '.act_dw'
|
||||
feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
|
||||
curr_stride *= stride
|
||||
block_dpr = drop_path_rate * block_idx / (num_blocks - 1) # stochastic depth linear decay rule
|
||||
drop_path = DropPath(block_dpr) if block_dpr > 0. else None
|
||||
features.append(LinearBottleneck(
|
||||
in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, se_ratio=se_ratio,
|
||||
ch_div=ch_div, drop_path=drop_path))
|
||||
ch_div=ch_div, act_layer=act_layer, dw_act_layer=dw_act_layer, drop_path=drop_path))
|
||||
prev_chs = chs
|
||||
feat_chs += [features[-1].feat_channels(feat_exp)]
|
||||
feat_chs += [features[-1].feat_channels()]
|
||||
pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
|
||||
feature_info += [dict(
|
||||
num_chs=pen_chs if feat_exp else feat_chs[-1], reduction=curr_stride,
|
||||
module=f'features.{len(features) - int(not feat_exp)}')]
|
||||
features.append(ConvBnAct(prev_chs, pen_chs, act_layer="swish"))
|
||||
feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=f'features.{len(features) - 1}')]
|
||||
features.append(ConvBnAct(prev_chs, pen_chs, act_layer=act_layer))
|
||||
return features, feature_info
|
||||
|
||||
|
||||
class ReXNetV1(nn.Module):
|
||||
def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
|
||||
initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
|
||||
ch_div=1, drop_rate=0.2, drop_path_rate=0., feature_location='bottleneck'):
|
||||
ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0.):
|
||||
super(ReXNetV1, self).__init__()
|
||||
self.drop_rate = drop_rate
|
||||
self.num_classes = num_classes
|
||||
|
@ -171,11 +167,11 @@ class ReXNetV1(nn.Module):
|
|||
assert output_stride == 32 # FIXME support dilation
|
||||
stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
|
||||
stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
|
||||
self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer='swish')
|
||||
self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer)
|
||||
|
||||
block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
|
||||
features, self.feature_info = _build_blocks(
|
||||
block_cfg, stem_chs, width_mult, ch_div, drop_path_rate, feature_location)
|
||||
block_cfg, stem_chs, width_mult, ch_div, act_layer, dw_act_layer, drop_path_rate)
|
||||
self.num_features = features[-1].out_channels
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
|
@ -202,8 +198,6 @@ class ReXNetV1(nn.Module):
|
|||
|
||||
def _create_rexnet(variant, pretrained, **kwargs):
|
||||
feature_cfg = dict(flatten_sequential=True)
|
||||
if kwargs.get('feature_location', '') == 'expansion':
|
||||
feature_cfg['feature_cls'] = 'hook'
|
||||
return build_model_with_cfg(
|
||||
ReXNetV1, variant, pretrained, default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, **kwargs)
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = '0.4.4'
|
||||
__version__ = '0.4.5'
|
||||
|
|
Loading…
Reference in New Issue