mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix regression in models with 1001 class pretrained weights. Improve batchnorm arg and BatchNormAct layer handling in several models.
This commit is contained in:
parent
aaa715b1e9
commit
9811e229f7
@ -83,7 +83,6 @@ def test_model_default_cfgs(model_name, batch_size):
|
||||
cfg = model.default_cfg
|
||||
|
||||
classifier = cfg['classifier']
|
||||
first_conv = cfg['first_conv']
|
||||
pool_size = cfg['pool_size']
|
||||
input_size = model.default_cfg['input_size']
|
||||
|
||||
@ -111,9 +110,16 @@ def test_model_default_cfgs(model_name, batch_size):
|
||||
# FIXME mobilenetv3 forward_features vs removed pooling differ
|
||||
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
|
||||
|
||||
# check classifier and first convolution names match those in default_cfg
|
||||
# check classifier name matches default_cfg
|
||||
assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params'
|
||||
assert first_conv + ".weight" in state_dict.keys(), f'{first_conv} not in model params'
|
||||
|
||||
# check first conv(s) names match default_cfg
|
||||
first_conv = cfg['first_conv']
|
||||
if isinstance(first_conv, str):
|
||||
first_conv = (first_conv,)
|
||||
assert isinstance(first_conv, (tuple, list))
|
||||
for fc in first_conv:
|
||||
assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
|
||||
|
||||
|
||||
if 'GITHUB_ACTIONS' not in os.environ:
|
||||
|
@ -7,6 +7,7 @@ This implementation is compatible with the pretrained weights from cypw's MXNet
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
@ -173,12 +174,14 @@ class DPN(nn.Module):
|
||||
self.drop_rate = drop_rate
|
||||
self.b = b
|
||||
assert output_stride == 32 # FIXME look into dilation support
|
||||
norm_layer = partial(BatchNormAct2d, eps=.001)
|
||||
fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act, inplace=False)
|
||||
bw_factor = 1 if small else 4
|
||||
blocks = OrderedDict()
|
||||
|
||||
# conv1
|
||||
blocks['conv1_1'] = ConvBnAct(
|
||||
in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_kwargs=dict(eps=.001))
|
||||
in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer)
|
||||
blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')]
|
||||
|
||||
@ -226,8 +229,7 @@ class DPN(nn.Module):
|
||||
in_chs += inc
|
||||
self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')]
|
||||
|
||||
def _fc_norm(f, eps): return BatchNormAct2d(f, eps=eps, act_layer=fc_act, inplace=False)
|
||||
blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=_fc_norm)
|
||||
blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer)
|
||||
|
||||
self.num_features = in_chs
|
||||
self.features = nn.Sequential(blocks)
|
||||
|
@ -42,10 +42,8 @@ for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd w
|
||||
|
||||
|
||||
class SeparableConv2d(nn.Module):
|
||||
def __init__(self, inplanes, planes, kernel_size=3, stride=1,
|
||||
dilation=1, bias=False, norm_layer=None, norm_kwargs=None):
|
||||
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None):
|
||||
super(SeparableConv2d, self).__init__()
|
||||
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation = dilation
|
||||
|
||||
@ -54,7 +52,7 @@ class SeparableConv2d(nn.Module):
|
||||
self.conv_dw = nn.Conv2d(
|
||||
inplanes, inplanes, kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, groups=inplanes, bias=bias)
|
||||
self.bn = norm_layer(num_features=inplanes, **norm_kwargs)
|
||||
self.bn = norm_layer(num_features=inplanes)
|
||||
# pointwise convolution
|
||||
self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
|
||||
|
||||
@ -66,10 +64,8 @@ class SeparableConv2d(nn.Module):
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True,
|
||||
norm_layer=None, norm_kwargs=None, ):
|
||||
def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None):
|
||||
super(Block, self).__init__()
|
||||
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
|
||||
if isinstance(planes, (list, tuple)):
|
||||
assert len(planes) == 3
|
||||
else:
|
||||
@ -80,7 +76,7 @@ class Block(nn.Module):
|
||||
self.skip = nn.Sequential()
|
||||
self.skip.add_module('conv1', nn.Conv2d(
|
||||
inplanes, outplanes, 1, stride=stride, bias=False)),
|
||||
self.skip.add_module('bn1', norm_layer(num_features=outplanes, **norm_kwargs))
|
||||
self.skip.add_module('bn1', norm_layer(num_features=outplanes))
|
||||
else:
|
||||
self.skip = None
|
||||
|
||||
@ -88,9 +84,8 @@ class Block(nn.Module):
|
||||
for i in range(3):
|
||||
rep['act%d' % (i + 1)] = nn.ReLU(inplace=True)
|
||||
rep['conv%d' % (i + 1)] = SeparableConv2d(
|
||||
inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
rep['bn%d' % (i + 1)] = norm_layer(planes[i], **norm_kwargs)
|
||||
inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer)
|
||||
rep['bn%d' % (i + 1)] = norm_layer(planes[i])
|
||||
inplanes = planes[i]
|
||||
|
||||
if not start_with_relu:
|
||||
@ -115,74 +110,63 @@ class Xception65(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
|
||||
norm_kwargs=None, drop_rate=0., global_pool='avg'):
|
||||
drop_rate=0., global_pool='avg'):
|
||||
super(Xception65, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
|
||||
if output_stride == 32:
|
||||
entry_block3_stride = 2
|
||||
exit_block20_stride = 2
|
||||
middle_block_dilation = 1
|
||||
exit_block_dilations = (1, 1)
|
||||
middle_dilation = 1
|
||||
exit_dilation = (1, 1)
|
||||
elif output_stride == 16:
|
||||
entry_block3_stride = 2
|
||||
exit_block20_stride = 1
|
||||
middle_block_dilation = 1
|
||||
exit_block_dilations = (1, 2)
|
||||
middle_dilation = 1
|
||||
exit_dilation = (1, 2)
|
||||
elif output_stride == 8:
|
||||
entry_block3_stride = 1
|
||||
exit_block20_stride = 1
|
||||
middle_block_dilation = 2
|
||||
exit_block_dilations = (2, 4)
|
||||
middle_dilation = 2
|
||||
exit_dilation = (2, 4)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# Entry flow
|
||||
self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = norm_layer(num_features=32, **norm_kwargs)
|
||||
self.bn1 = norm_layer(num_features=32)
|
||||
self.act1 = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = norm_layer(num_features=64)
|
||||
self.act2 = nn.ReLU(inplace=True)
|
||||
|
||||
self.block1 = Block(
|
||||
64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer)
|
||||
self.block1_act = nn.ReLU(inplace=True)
|
||||
self.block2 = Block(
|
||||
128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
self.block3 = Block(
|
||||
256, 728, stride=entry_block3_stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer)
|
||||
self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer)
|
||||
|
||||
# Middle flow
|
||||
self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
|
||||
728, 728, stride=1, dilation=middle_block_dilation,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs)) for i in range(4, 20)]))
|
||||
728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)]))
|
||||
|
||||
# Exit flow
|
||||
self.block20 = Block(
|
||||
728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0],
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer)
|
||||
self.block20_act = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv3 = SeparableConv2d(
|
||||
1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
self.bn3 = norm_layer(num_features=1536, **norm_kwargs)
|
||||
self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
|
||||
self.bn3 = norm_layer(num_features=1536)
|
||||
self.act3 = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv4 = SeparableConv2d(
|
||||
1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
|
||||
self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
|
||||
self.bn4 = norm_layer(num_features=1536)
|
||||
self.act4 = nn.ReLU(inplace=True)
|
||||
|
||||
self.num_features = 2048
|
||||
self.conv5 = SeparableConv2d(
|
||||
1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
|
||||
1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
|
||||
self.bn5 = norm_layer(num_features=self.num_features)
|
||||
self.act5 = nn.ReLU(inplace=True)
|
||||
self.feature_info = [
|
||||
dict(num_chs=64, reduction=2, module='act2'),
|
||||
|
@ -148,6 +148,31 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
|
||||
_logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
|
||||
|
||||
|
||||
def adapt_input_conv(in_chans, conv_weight):
|
||||
conv_type = conv_weight.dtype
|
||||
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
|
||||
O, I, J, K = conv_weight.shape
|
||||
if in_chans == 1:
|
||||
if I > 3:
|
||||
assert conv_weight.shape[1] % 3 == 0
|
||||
# For models with space2depth stems
|
||||
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
|
||||
conv_weight = conv_weight.sum(dim=2, keepdim=False)
|
||||
else:
|
||||
conv_weight = conv_weight.sum(dim=1, keepdim=True)
|
||||
elif in_chans != 3:
|
||||
if I != 3:
|
||||
raise NotImplementedError('Weight format not supported by conversion.')
|
||||
else:
|
||||
# NOTE this strategy should be better than random init, but there could be other combinations of
|
||||
# the original RGB input layer weights that'd work better for specific cases.
|
||||
repeat = int(math.ceil(in_chans / 3))
|
||||
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
|
||||
conv_weight *= (3 / float(in_chans))
|
||||
conv_weight = conv_weight.to(conv_type)
|
||||
return conv_weight
|
||||
|
||||
|
||||
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
|
||||
if cfg is None:
|
||||
cfg = getattr(model, 'default_cfg')
|
||||
@ -159,56 +184,35 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
|
||||
if filter_fn is not None:
|
||||
state_dict = filter_fn(state_dict)
|
||||
|
||||
if in_chans == 1:
|
||||
conv1_name = cfg['first_conv']
|
||||
_logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
|
||||
conv1_weight = state_dict[conv1_name + '.weight']
|
||||
# Some weights are in torch.half, ensure it's float for sum on CPU
|
||||
conv1_type = conv1_weight.dtype
|
||||
conv1_weight = conv1_weight.float()
|
||||
O, I, J, K = conv1_weight.shape
|
||||
if I > 3:
|
||||
assert conv1_weight.shape[1] % 3 == 0
|
||||
# For models with space2depth stems
|
||||
conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
|
||||
conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
|
||||
else:
|
||||
conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
|
||||
conv1_weight = conv1_weight.to(conv1_type)
|
||||
state_dict[conv1_name + '.weight'] = conv1_weight
|
||||
elif in_chans != 3:
|
||||
conv1_name = cfg['first_conv']
|
||||
conv1_weight = state_dict[conv1_name + '.weight']
|
||||
conv1_type = conv1_weight.dtype
|
||||
conv1_weight = conv1_weight.float()
|
||||
O, I, J, K = conv1_weight.shape
|
||||
if I != 3:
|
||||
_logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
|
||||
del state_dict[conv1_name + '.weight']
|
||||
strict = False
|
||||
else:
|
||||
# NOTE this strategy should be better than random init, but there could be other combinations of
|
||||
# the original RGB input layer weights that'd work better for specific cases.
|
||||
_logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
|
||||
repeat = int(math.ceil(in_chans / 3))
|
||||
conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
|
||||
conv1_weight *= (3 / float(in_chans))
|
||||
conv1_weight = conv1_weight.to(conv1_type)
|
||||
state_dict[conv1_name + '.weight'] = conv1_weight
|
||||
input_convs = cfg.get('first_conv', None)
|
||||
if input_convs is not None:
|
||||
if isinstance(input_convs, str):
|
||||
input_convs = (input_convs,)
|
||||
for input_conv_name in input_convs:
|
||||
weight_name = input_conv_name + '.weight'
|
||||
try:
|
||||
state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
|
||||
_logger.info(
|
||||
f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
|
||||
except NotImplementedError as e:
|
||||
del state_dict[weight_name]
|
||||
strict = False
|
||||
_logger.warning(
|
||||
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
|
||||
|
||||
classifier_name = cfg['classifier']
|
||||
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
||||
# FIXME this special case is problematic as number of pretrained weight sources increases
|
||||
# special case for imagenet trained models with extra background class in pretrained weights
|
||||
classifier_weight = state_dict[classifier_name + '.weight']
|
||||
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
|
||||
classifier_bias = state_dict[classifier_name + '.bias']
|
||||
state_dict[classifier_name + '.bias'] = classifier_bias[1:]
|
||||
elif num_classes != cfg['num_classes']:
|
||||
# completely discard fully connected for all other differences between pretrained and created model
|
||||
label_offset = cfg.get('label_offset', 0)
|
||||
if num_classes != cfg['num_classes']:
|
||||
# completely discard fully connected if model num_classes doesn't match pretrained weights
|
||||
del state_dict[classifier_name + '.weight']
|
||||
del state_dict[classifier_name + '.bias']
|
||||
strict = False
|
||||
elif label_offset > 0:
|
||||
# special case for pretrained weights with an extra background class in pretrained weights
|
||||
classifier_weight = state_dict[classifier_name + '.weight']
|
||||
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
|
||||
classifier_bias = state_dict[classifier_name + '.bias']
|
||||
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
|
||||
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
@ -17,18 +17,20 @@ default_cfgs = {
|
||||
# ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
|
||||
'inception_resnet_v2': {
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
|
||||
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||
'crop_pct': 0.8975, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
|
||||
'label_offset': 1, # 1001 classes in pretrained weights
|
||||
},
|
||||
# ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
|
||||
'ens_adv_inception_resnet_v2': {
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
|
||||
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||
'crop_pct': 0.8975, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
|
||||
'label_offset': 1, # 1001 classes in pretrained weights
|
||||
}
|
||||
}
|
||||
|
||||
@ -222,7 +224,7 @@ class Block8(nn.Module):
|
||||
|
||||
|
||||
class InceptionResnetV2(nn.Module):
|
||||
def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
|
||||
def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
|
||||
super(InceptionResnetV2, self).__init__()
|
||||
self.drop_rate = drop_rate
|
||||
self.num_classes = num_classes
|
||||
|
@ -16,10 +16,11 @@ __all__ = ['InceptionV4']
|
||||
default_cfgs = {
|
||||
'inception_v4': {
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
|
||||
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'features.0.conv', 'classifier': 'last_linear',
|
||||
'label_offset': 1, # 1001 classes in pretrained weights
|
||||
}
|
||||
}
|
||||
|
||||
@ -241,7 +242,7 @@ class InceptionC(nn.Module):
|
||||
|
||||
|
||||
class InceptionV4(nn.Module):
|
||||
def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
|
||||
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
|
||||
super(InceptionV4, self).__init__()
|
||||
assert output_stride == 32
|
||||
self.drop_rate = drop_rate
|
||||
|
@ -12,7 +12,7 @@ from .conv_bn_act import ConvBnAct
|
||||
from .create_act import create_act_layer, get_act_layer, get_act_fn
|
||||
from .create_attn import get_attn, create_attn
|
||||
from .create_conv2d import create_conv2d
|
||||
from .create_norm_act import create_norm_act, get_norm_act_layer
|
||||
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
|
||||
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
|
||||
from .eca import EcaModule, CecaModule
|
||||
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
||||
|
@ -5,23 +5,23 @@ Hacked together by / Copyright 2020 Ross Wightman
|
||||
from torch import nn as nn
|
||||
|
||||
from .create_conv2d import create_conv2d
|
||||
from .create_norm_act import convert_norm_act_type
|
||||
from .create_norm_act import convert_norm_act
|
||||
|
||||
|
||||
class ConvBnAct(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None, act_layer=nn.ReLU, apply_act=True,
|
||||
drop_block=None, aa_layer=None):
|
||||
bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None,
|
||||
drop_block=None):
|
||||
super(ConvBnAct, self).__init__()
|
||||
use_aa = aa_layer is not None
|
||||
|
||||
self.conv = create_conv2d(
|
||||
in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
|
||||
padding=padding, dilation=dilation, groups=groups, bias=False)
|
||||
padding=padding, dilation=dilation, groups=groups, bias=bias)
|
||||
|
||||
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
|
||||
norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs)
|
||||
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args)
|
||||
norm_act_layer = convert_norm_act(norm_layer, act_layer)
|
||||
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block)
|
||||
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None
|
||||
|
||||
@property
|
||||
|
@ -9,6 +9,8 @@ from .cbam import CbamModule, LightCbamModule
|
||||
|
||||
|
||||
def get_attn(attn_type):
|
||||
if isinstance(attn_type, torch.nn.Module):
|
||||
return attn_type
|
||||
module_cls = None
|
||||
if attn_type is not None:
|
||||
if isinstance(attn_type, str):
|
||||
|
@ -19,6 +19,7 @@ from .inplace_abn import InplaceAbn
|
||||
_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn}
|
||||
_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type
|
||||
|
||||
|
||||
def get_norm_act_layer(layer_class):
|
||||
layer_class = layer_class.replace('_', '').lower()
|
||||
if layer_class.startswith("batchnorm"):
|
||||
@ -47,16 +48,22 @@ def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwarg
|
||||
return layer_instance
|
||||
|
||||
|
||||
def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None):
|
||||
def convert_norm_act(norm_layer, act_layer):
|
||||
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
|
||||
assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
|
||||
norm_act_args = norm_kwargs.copy() if norm_kwargs else {}
|
||||
norm_act_kwargs = {}
|
||||
|
||||
# unbind partial fn, so args can be rebound later
|
||||
if isinstance(norm_layer, functools.partial):
|
||||
norm_act_kwargs.update(norm_layer.keywords)
|
||||
norm_layer = norm_layer.func
|
||||
|
||||
if isinstance(norm_layer, str):
|
||||
norm_act_layer = get_norm_act_layer(norm_layer)
|
||||
elif norm_layer in _NORM_ACT_TYPES:
|
||||
norm_act_layer = norm_layer
|
||||
elif isinstance(norm_layer, (types.FunctionType, functools.partial)):
|
||||
# assuming this is a lambda/fn/bound partial that creates norm_act layer
|
||||
elif isinstance(norm_layer, types.FunctionType):
|
||||
# if function type, must be a lambda/fn that creates a norm_act layer
|
||||
norm_act_layer = norm_layer
|
||||
else:
|
||||
type_name = norm_layer.__name__.lower()
|
||||
@ -66,9 +73,11 @@ def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None):
|
||||
norm_act_layer = GroupNormAct
|
||||
else:
|
||||
assert False, f"No equivalent norm_act layer for {type_name}"
|
||||
|
||||
if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
|
||||
# Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
|
||||
# pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
|
||||
# In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
|
||||
# It is intended that functions/partial does not trigger this, they should define act.
|
||||
norm_act_args.update(dict(act_layer=act_layer))
|
||||
return norm_act_layer, norm_act_args
|
||||
norm_act_kwargs.setdefault('act_layer', act_layer)
|
||||
if norm_act_kwargs:
|
||||
norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args
|
||||
return norm_act_layer
|
||||
|
@ -24,7 +24,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
|
||||
act_args = dict(inplace=True) if inplace else {}
|
||||
self.act = act_layer(**act_args)
|
||||
else:
|
||||
self.act = None
|
||||
self.act = nn.Identity()
|
||||
|
||||
def _forward_jit(self, x):
|
||||
""" A cut & paste of the contents of the PyTorch BatchNorm2d forward function
|
||||
@ -62,8 +62,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
|
||||
x = self._forward_jit(x)
|
||||
else:
|
||||
x = self._forward_python(x)
|
||||
if self.act is not None:
|
||||
x = self.act(x)
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
|
||||
@ -75,12 +74,12 @@ class GroupNormAct(nn.GroupNorm):
|
||||
if isinstance(act_layer, str):
|
||||
act_layer = get_act_layer(act_layer)
|
||||
if act_layer is not None and apply_act:
|
||||
self.act = act_layer(inplace=inplace)
|
||||
act_args = dict(inplace=True) if inplace else {}
|
||||
self.act = act_layer(**act_args)
|
||||
else:
|
||||
self.act = None
|
||||
self.act = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
|
||||
if self.act is not None:
|
||||
x = self.act(x)
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
@ -8,17 +8,16 @@ Hacked together by / Copyright 2020 Ross Wightman
|
||||
from torch import nn as nn
|
||||
|
||||
from .create_conv2d import create_conv2d
|
||||
from .create_norm_act import convert_norm_act_type
|
||||
from .create_norm_act import convert_norm_act
|
||||
|
||||
|
||||
class SeparableConvBnAct(nn.Module):
|
||||
""" Separable Conv w/ trailing Norm and Activation
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
|
||||
channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
act_layer=nn.ReLU, apply_act=True, drop_block=None):
|
||||
channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU,
|
||||
apply_act=True, drop_block=None):
|
||||
super(SeparableConvBnAct, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
|
||||
self.conv_dw = create_conv2d(
|
||||
in_channels, int(in_channels * channel_multiplier), kernel_size,
|
||||
@ -27,8 +26,8 @@ class SeparableConvBnAct(nn.Module):
|
||||
self.conv_pw = create_conv2d(
|
||||
int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
|
||||
|
||||
norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs)
|
||||
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args)
|
||||
norm_act_layer = convert_norm_act(norm_layer, act_layer)
|
||||
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block)
|
||||
|
||||
@property
|
||||
def in_channels(self):
|
||||
|
@ -1,6 +1,9 @@
|
||||
""" NasNet-A (Large)
|
||||
nasnetalarge implementation grabbed from Cadene's pretrained models
|
||||
https://github.com/Cadene/pretrained-models.pytorch
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -20,9 +23,10 @@ default_cfgs = {
|
||||
'interpolation': 'bicubic',
|
||||
'mean': (0.5, 0.5, 0.5),
|
||||
'std': (0.5, 0.5, 0.5),
|
||||
'num_classes': 1001,
|
||||
'num_classes': 1000,
|
||||
'first_conv': 'conv0.conv',
|
||||
'classifier': 'last_linear',
|
||||
'label_offset': 1, # 1001 classes in pretrained weights
|
||||
},
|
||||
}
|
||||
|
||||
@ -418,7 +422,7 @@ class NASNetALarge(nn.Module):
|
||||
|
||||
self.conv0 = ConvBnAct(
|
||||
in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2,
|
||||
norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)
|
||||
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)
|
||||
|
||||
self.cell_stem_0 = CellStem0(
|
||||
self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type)
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -26,9 +27,10 @@ default_cfgs = {
|
||||
'interpolation': 'bicubic',
|
||||
'mean': (0.5, 0.5, 0.5),
|
||||
'std': (0.5, 0.5, 0.5),
|
||||
'num_classes': 1001,
|
||||
'num_classes': 1000,
|
||||
'first_conv': 'conv_0.conv',
|
||||
'classifier': 'last_linear',
|
||||
'label_offset': 1, # 1001 classes in pretrained weights
|
||||
},
|
||||
}
|
||||
|
||||
@ -234,7 +236,7 @@ class Cell(CellBase):
|
||||
|
||||
|
||||
class PNASNet5Large(nn.Module):
|
||||
def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''):
|
||||
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''):
|
||||
super(PNASNet5Large, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
@ -243,7 +245,7 @@ class PNASNet5Large(nn.Module):
|
||||
|
||||
self.conv_0 = ConvBnAct(
|
||||
in_chans, 96, kernel_size=3, stride=2, padding=0,
|
||||
norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)
|
||||
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)
|
||||
|
||||
self.cell_stem_0 = CellStem0(
|
||||
in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type)
|
||||
|
@ -5,7 +5,7 @@ https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zo
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -43,9 +43,8 @@ default_cfgs = dict(
|
||||
class SeparableConv2d(nn.Module):
|
||||
def __init__(
|
||||
self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='',
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
||||
super(SeparableConv2d, self).__init__()
|
||||
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation = dilation
|
||||
|
||||
@ -53,7 +52,7 @@ class SeparableConv2d(nn.Module):
|
||||
self.conv_dw = create_conv2d(
|
||||
inplanes, inplanes, kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, depthwise=True)
|
||||
self.bn_dw = norm_layer(inplanes, **norm_kwargs)
|
||||
self.bn_dw = norm_layer(inplanes)
|
||||
if act_layer is not None:
|
||||
self.act_dw = act_layer(inplace=True)
|
||||
else:
|
||||
@ -61,7 +60,7 @@ class SeparableConv2d(nn.Module):
|
||||
|
||||
# pointwise convolution
|
||||
self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1)
|
||||
self.bn_pw = norm_layer(planes, **norm_kwargs)
|
||||
self.bn_pw = norm_layer(planes)
|
||||
if act_layer is not None:
|
||||
self.act_pw = act_layer(inplace=True)
|
||||
else:
|
||||
@ -82,17 +81,15 @@ class SeparableConv2d(nn.Module):
|
||||
class XceptionModule(nn.Module):
|
||||
def __init__(
|
||||
self, in_chs, out_chs, stride=1, dilation=1, pad_type='',
|
||||
start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None, norm_kwargs=None):
|
||||
start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None):
|
||||
super(XceptionModule, self).__init__()
|
||||
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
|
||||
out_chs = to_3tuple(out_chs)
|
||||
self.in_channels = in_chs
|
||||
self.out_channels = out_chs[-1]
|
||||
self.no_skip = no_skip
|
||||
if not no_skip and (self.out_channels != self.in_channels or stride != 1):
|
||||
self.shortcut = ConvBnAct(
|
||||
in_chs, self.out_channels, 1, stride=stride,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, act_layer=None)
|
||||
in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, act_layer=None)
|
||||
else:
|
||||
self.shortcut = None
|
||||
|
||||
@ -103,7 +100,7 @@ class XceptionModule(nn.Module):
|
||||
self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0))
|
||||
self.stack.add_module(f'conv{i + 1}', SeparableConv2d(
|
||||
in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type,
|
||||
act_layer=separable_act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs))
|
||||
act_layer=separable_act_layer, norm_layer=norm_layer))
|
||||
in_chs = out_chs[i]
|
||||
|
||||
def forward(self, x):
|
||||
@ -121,14 +118,13 @@ class XceptionAligned(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32,
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_rate=0., global_pool='avg'):
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'):
|
||||
super(XceptionAligned, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
assert output_stride in (8, 16, 32)
|
||||
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
|
||||
|
||||
layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
|
||||
self.stem = nn.Sequential(*[
|
||||
ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args),
|
||||
ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args)
|
||||
@ -196,7 +192,7 @@ def xception41(pretrained=False, **kwargs):
|
||||
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
|
||||
dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
|
||||
]
|
||||
model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs)
|
||||
model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
|
||||
return _xception('xception41', pretrained=pretrained, **model_args)
|
||||
|
||||
|
||||
@ -215,7 +211,7 @@ def xception65(pretrained=False, **kwargs):
|
||||
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
|
||||
dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
|
||||
]
|
||||
model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs)
|
||||
model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
|
||||
return _xception('xception65', pretrained=pretrained, **model_args)
|
||||
|
||||
|
||||
@ -236,5 +232,5 @@ def xception71(pretrained=False, **kwargs):
|
||||
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
|
||||
dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
|
||||
]
|
||||
model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs)
|
||||
model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
|
||||
return _xception('xception71', pretrained=pretrained, **model_args)
|
||||
|
Loading…
x
Reference in New Issue
Block a user