Merge branch 'yoniaflalo-adding_ECA_resnet'
commit 7a9942a75e

timm/models/helpers.py

@@ -1,8 +1,11 @@
import torch
import torch.nn as nn
from copy import deepcopy
import torch.utils.model_zoo as model_zoo
import os
import logging
from collections import OrderedDict
from timm.models.layers.conv2d_same import Conv2dSame


def load_state_dict(checkpoint_path, use_ema=False):

@@ -98,7 +101,96 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None
    model.load_state_dict(state_dict, strict=strict)


def extract_layer(model, layer):
    # Walk a dotted module path (e.g. 'layer1.0.conv1') and return the module
    # it names, tolerating a DataParallel-style 'module.' prefix. If a path
    # component cannot be resolved, the deepest module reached is returned.
    layer = layer.split('.')
    module = model
    if hasattr(model, 'module') and layer[0] != 'module':
        module = model.module
    if not hasattr(model, 'module') and layer[0] == 'module':
        layer = layer[1:]
    for l in layer:
        if hasattr(module, l):
            if not l.isdigit():
                module = getattr(module, l)
            else:
                module = module[int(l)]
        else:
            return module
    return module


def set_layer(model, layer, val):
    # Replace the module at a dotted path with `val`: first find how deep the
    # path resolves, then walk to the parent module and assign the final
    # attribute (digit components index into containers such as Sequential).
    layer = layer.split('.')
    module = model
    if hasattr(model, 'module') and layer[0] != 'module':
        module = model.module
    lst_index = 0
    module2 = module
    for l in layer:
        if hasattr(module2, l):
            if not l.isdigit():
                module2 = getattr(module2, l)
            else:
                module2 = module2[int(l)]
            lst_index += 1
    lst_index -= 1
    for l in layer[:lst_index]:
        if not l.isdigit():
            module = getattr(module, l)
        else:
            module = module[int(l)]
    l = layer[lst_index]
    setattr(module, l, val)
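
# Usage sketch (illustrative; `m` is some hypothetical ResNet and the 1x1
# replacement conv is made up for the example):
#
#   old_conv = extract_layer(m, 'layer1.0.conv1')
#   set_layer(m, 'layer1.0.conv1', nn.Conv2d(old_conv.in_channels, 32, kernel_size=1, bias=False))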


def adapt_model_from_string(parent_module, model_string):
    # Parse a '***'-separated list of 'name:[shape]' entries describing pruned
    # tensor shapes, then rebuild every Conv2d/BatchNorm2d/Linear of
    # `parent_module` with the pruned channel counts.
    separator = '***'
    state_dict = {}
    lst_shape = model_string.split(separator)
    for k in lst_shape:
        k = k.split(':')
        key = k[0]
        shape = k[1][1:-1].split(',')
        if shape[0] != '':
            state_dict[key] = [int(i) for i in shape]

    new_module = deepcopy(parent_module)
    for n, m in parent_module.named_modules():
        old_module = extract_layer(parent_module, n)
        if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
            if isinstance(old_module, Conv2dSame):
                conv = Conv2dSame
            else:
                conv = nn.Conv2d
            s = state_dict[n + '.weight']
            in_channels = s[1]
            out_channels = s[0]
            g = 1
            if old_module.groups > 1:
                # grouped/depthwise conv: channel count is tied to group count
                in_channels = out_channels
                g = in_channels
            new_conv = conv(
                in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
                bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                groups=g, stride=old_module.stride)
            set_layer(new_module, n, new_conv)
        if isinstance(old_module, nn.BatchNorm2d):
            new_bn = nn.BatchNorm2d(
                num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            set_layer(new_module, n, new_bn)
        if isinstance(old_module, nn.Linear):
            new_fc = nn.Linear(
                in_features=state_dict[n + '.weight'][1], out_features=old_module.out_features,
                bias=old_module.bias is not None)
            set_layer(new_module, n, new_fc)
    new_module.eval()
    parent_module.eval()

    return new_module


def adapt_model_from_file(parent_module, model_variant):
    # Load the pruning spec stored alongside the models under 'pruned/<variant>.txt'.
    adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
    with open(adapt_file, 'r') as f:
        return adapt_model_from_string(parent_module, f.read().strip())
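
To make the shape-string format concrete, here is a minimal sketch. The toy two-layer network and the pruning spec are invented for illustration; only adapt_model_from_string comes from the diff above.

from collections import OrderedDict
import torch
import torch.nn as nn
from timm.models.helpers import adapt_model_from_string

# Toy network; module names must match the keys in the spec string.
net = nn.Sequential(OrderedDict([
    ('conv', nn.Conv2d(3, 8, kernel_size=3)),
    ('bn', nn.BatchNorm2d(8)),
]))

# Prune the conv from 8 to 5 output channels; the BatchNorm follows suit.
spec = 'conv.weight:[5, 3, 3, 3]***bn.weight:[5]'
pruned = adapt_model_from_string(net, spec)

out = pruned(torch.randn(1, 3, 8, 8))
print(out.shape)  # torch.Size([1, 5, 6, 6])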

timm/models/pruned/ecaresnet50d_pruned.txt (new file: the pruned shape spec consumed by adapt_model_from_file, one 'name:[shape]' entry per '***' separator, stored as a single line)

@@ -0,0 +1 @@
conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[47, 64, 1, 1]***layer1.0.bn1.weight:[47]***layer1.0.conv2.weight:[18, 47, 3, 3]***layer1.0.bn2.weight:[18]***layer1.0.conv3.weight:[19, 18, 1, 1]***layer1.0.bn3.weight:[19]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[19, 64, 1, 1]***layer1.0.downsample.2.weight:[19]***layer1.1.conv1.weight:[52, 19, 1, 1]***layer1.1.bn1.weight:[52]***layer1.1.conv2.weight:[22, 52, 3, 3]***layer1.1.bn2.weight:[22]***layer1.1.conv3.weight:[19, 22, 1, 1]***layer1.1.bn3.weight:[19]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[64, 19, 1, 1]***layer1.2.bn1.weight:[64]***layer1.2.conv2.weight:[35, 64, 3, 3]***layer1.2.bn2.weight:[35]***layer1.2.conv3.weight:[19, 35, 1, 1]***layer1.2.bn3.weight:[19]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[85, 19, 1, 1]***layer2.0.bn1.weight:[85]***layer2.0.conv2.weight:[37, 85, 3, 3]***layer2.0.bn2.weight:[37]***layer2.0.conv3.weight:[171, 37, 1, 1]***layer2.0.bn3.weight:[171]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[171, 19, 1, 1]***layer2.0.downsample.2.weight:[171]***layer2.1.conv1.weight:[107, 171, 1, 1]***layer2.1.bn1.weight:[107]***layer2.1.conv2.weight:[80, 107, 3, 3]***layer2.1.bn2.weight:[80]***layer2.1.conv3.weight:[171, 80, 1, 1]***layer2.1.bn3.weight:[171]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[120, 171, 1, 1]***layer2.2.bn1.weight:[120]***layer2.2.conv2.weight:[85, 120, 3, 3]***layer2.2.bn2.weight:[85]***layer2.2.conv3.weight:[171, 85, 1, 1]***layer2.2.bn3.weight:[171]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[125, 171, 1, 1]***layer2.3.bn1.weight:[125]***layer2.3.conv2.weight:[87, 125, 3, 3]***layer2.3.bn2.weight:[87]***layer2.3.conv3.weight:[171, 87, 1, 1]***layer2.3.bn3.weight:[171]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[198, 171, 1, 1]***layer3.0.bn1.weight:[198]***layer3.0.conv2.weight:[126, 198, 3, 3]***layer3.0.bn2.weight:[126]***layer3.0.conv3.weight:[818, 126, 1, 1]***layer3.0.bn3.weight:[818]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[818, 171, 1, 1]***layer3.0.downsample.2.weight:[818]***layer3.1.conv1.weight:[255, 818, 1, 1]***layer3.1.bn1.weight:[255]***layer3.1.conv2.weight:[232, 255, 3, 3]***layer3.1.bn2.weight:[232]***layer3.1.conv3.weight:[818, 232, 1, 1]***layer3.1.bn3.weight:[818]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[256, 818, 1, 1]***layer3.2.bn1.weight:[256]***layer3.2.conv2.weight:[233, 256, 3, 3]***layer3.2.bn2.weight:[233]***layer3.2.conv3.weight:[818, 233, 1, 1]***layer3.2.bn3.weight:[818]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[253, 818, 1, 1]***layer3.3.bn1.weight:[253]***layer3.3.conv2.weight:[235, 253, 3, 3]***layer3.3.bn2.weight:[235]***layer3.3.conv3.weight:[818, 235, 1, 1]***layer3.3.bn3.weight:[818]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[256, 818, 1, 1]***layer3.4.bn1.weight:[256]***layer3.4.conv2.weight:[225, 256, 3, 3]***layer3.4.bn2.weight:[225]***layer3.4.conv3.weight:[818, 225, 1, 1]***layer3.4.bn3.weight:[818]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[256, 818, 1, 1]***layer3.5.bn1.weight:[256]***layer3.5.conv2.weight:[239, 256, 3, 3]***layer3.5.bn2.weight:[239]***layer3.5.conv3.weight:[818, 239, 1, 1]***layer3.5.bn3.weight:[818]***layer3.5.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[492, 818, 1, 1]***layer4.0.bn1.weight:[492]***layer4.0.conv2.weight:[237, 492, 3, 3]***layer4.0.bn2.weight:[237]***layer4.0.conv3.weight:[2022, 237, 1, 1]***layer4.0.bn3.weight:[2022]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2022, 818, 1, 1]***layer4.0.downsample.2.weight:[2022]***layer4.1.conv1.weight:[512, 2022, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[500, 512, 3, 3]***layer4.1.bn2.weight:[500]***layer4.1.conv3.weight:[2022, 500, 1, 1]***layer4.1.bn3.weight:[2022]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2022, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[490, 512, 3, 3]***layer4.2.bn2.weight:[490]***layer4.2.conv3.weight:[2022, 490, 1, 1]***layer4.2.bn3.weight:[2022]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2022]***layer1_2_conv3_M.weight:[256, 19]***layer2_3_conv3_M.weight:[512, 171]***layer3_5_conv3_M.weight:[1024, 818]***layer4_2_conv3_M.weight:[2048, 2022]
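
As a quick sanity check, a spec file like the one above can be parsed with the same 'name:[shape]' convention; the path below assumes a checkout of this branch. The trailing *_M entries belong to the pruning method's auxiliary matrices and are simply ignored by adapt_model_from_string, since no module in the ResNet matches those names.

spec = open('timm/models/pruned/ecaresnet50d_pruned.txt').read().strip()
entries = dict(
    (name, [int(v) for v in shape[1:-1].split(',')])
    for name, shape in (item.split(':') for item in spec.split('***'))
)
print(len(entries), entries['fc.weight'])  # fc.weight stays [1000, 2022]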

timm/models/resnet.py

@@ -11,11 +11,10 @@ import torch.nn as nn
import torch.nn.functional as F

from .registry import register_model
-from .helpers import load_pretrained
+from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


__all__ = ['ResNet', 'BasicBlock', 'Bottleneck']  # model_registry will add each entrypoint fn to this

@@ -104,6 +103,21 @@ default_cfgs = {
        interpolation='bicubic'),
    'ecaresnet18': _cfg(),
    'ecaresnet50': _cfg(),
    'ecaresnetlight': _cfg(
        url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth',
        interpolation='bicubic'),
    'ecaresnet50d': _cfg(
        url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth',
        interpolation='bicubic'),
    'ecaresnet50d_pruned': _cfg(
        url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth',
        interpolation='bicubic'),
    'ecaresnet101d': _cfg(
        url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth',
        interpolation='bicubic'),
    'ecaresnet101d_pruned': _cfg(
        url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
        interpolation='bicubic'),
}

@@ -1022,3 +1036,81 @@ def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


@register_model
def ecaresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a ResNet-50-D model with ECA.
    """
    default_cfg = default_cfgs['ecaresnet50d']
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


@register_model
def ecaresnet50d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a pruned ResNet-50-D model with ECA.
    The pruning was obtained with the method of https://arxiv.org/pdf/2002.08258.pdf
    """
    variant = 'ecaresnet50d_pruned'
    default_cfg = default_cfgs[variant]
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
    model.default_cfg = default_cfg
    model = adapt_model_from_file(model, variant)

    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


@register_model
def ecaresnetlight(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a light ResNet-D model with ECA.
    """
    default_cfg = default_cfgs['ecaresnetlight']
    model = ResNet(
        Bottleneck, [1, 1, 11, 3], stem_width=32, avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


@register_model
def ecaresnet101d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a ResNet-101-D model with ECA.
    """
    default_cfg = default_cfgs['ecaresnet101d']
    model = ResNet(
        Bottleneck, [3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


@register_model
def ecaresnet101d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a pruned ResNet-101-D model with ECA.
    The pruning was obtained with the method of https://arxiv.org/pdf/2002.08258.pdf
    """
    variant = 'ecaresnet101d_pruned'
    default_cfg = default_cfgs[variant]
    model = ResNet(
        Bottleneck, [3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
    model.default_cfg = default_cfg
    model = adapt_model_from_file(model, variant)

    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
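
Once registered, the new variants are created like any other timm model. A minimal sketch (pretrained=False here to avoid the weight download; with pretrained=True the checkpoint URLs from default_cfgs above are used):

import torch
import timm

model = timm.create_model('ecaresnet50d_pruned', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])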