mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Replace all None with nn.Identity() in all models' reset_classifier when a falsy num_classes is given. Make small code refactorings.

This commit is contained in:
parent 6cc11a8821
commit a7ebe09029
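
The pattern this commit applies across every model, as a minimal runnable sketch (TinyNet and its sizes are illustrative stand-ins, not code from the repository):

    import torch
    import torch.nn as nn

    class TinyNet(nn.Module):
        """Toy model mimicking the timm head layout: features -> pool -> classifier."""
        def __init__(self, num_features=8, num_classes=10):
            super().__init__()
            self.num_features = num_features
            self.global_pool = nn.AdaptiveAvgPool2d(1)
            self.classifier = nn.Linear(num_features, num_classes)

        def reset_classifier(self, num_classes):
            # After this commit: a falsy num_classes installs a no-op module
            # instead of None, so forward() needs no special-casing.
            if num_classes:
                self.classifier = nn.Linear(self.num_features, num_classes)
            else:
                self.classifier = nn.Identity()

        def forward(self, x):
            x = self.global_pool(x).flatten(1)
            return self.classifier(x)  # safe whether or not a head is attached

    net = TinyNet()
    net.reset_classifier(0)               # strip the head for feature extraction
    feats = net(torch.randn(2, 8, 4, 4))  # shape (2, 8)

Previously reset_classifier(0) left the classifier attribute set to None, so any code path that called the head unconditionally would fail; nn.Identity() keeps the module graph intact.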
timm/models/densenet.py
@@ -2,17 +2,17 @@
 This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with
 fixed kwargs passthrough and addition of dynamic global avg/max pool.
 """
-import re
 from collections import OrderedDict
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+import re
+from .registry import register_model
 
 __all__ = ['DenseNet']
 
@@ -85,6 +85,7 @@ class DenseNet(nn.Module):
         drop_rate (float) - dropout rate after each dense layer
         num_classes (int) - number of classification classes
     """
+
     def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                  num_init_features=64, bn_size=4, drop_rate=0,
                  num_classes=1000, in_chans=3, global_pool='avg'):
@@ -127,8 +128,11 @@ class DenseNet(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.classifier = nn.Linear(
-            self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
+        if num_classes:
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.classifier = nn.Linear(num_features, num_classes)
+        else:
+            self.classifier = nn.Identity()
 
     def forward_features(self, x):
         x = self.features(x)
@@ -157,7 +161,6 @@ def _filter_pretrained(state_dict):
     return state_dict
 
 
-
 @register_model
 def densenet121(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     r"""Densenet-121 model from
timm/models/dla.py
@@ -11,11 +11,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-
+from .registry import register_model
 
 __all__ = ['DLA']
 
@@ -51,6 +50,7 @@ default_cfgs = {
 
 class DlaBasic(nn.Module):
     """DLA Basic"""
+
     def __init__(self, inplanes, planes, stride=1, dilation=1, **_):
         super(DlaBasic, self).__init__()
         self.conv1 = nn.Conv2d(
@@ -170,7 +170,7 @@ class DlaBottle2neck(nn.Module):
             sp = bn(sp)
             sp = self.relu(sp)
             spo.append(sp)
-        if self.scale > 1 :
+        if self.scale > 1:
             spo.append(self.pool(spx[-1]) if self.is_first else spx[-1])
         out = torch.cat(spo, 1)
 
@@ -304,9 +304,10 @@ class DLA(nn.Module):
         self.num_classes = num_classes
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         if num_classes:
-            self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, 1, bias=True)
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.fc = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
         else:
-            self.fc = None
+            self.fc = nn.Identity()
 
     def forward_features(self, x):
         x = self.base_layer(x)
timm/models/dpn.py
@@ -9,16 +9,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from collections import OrderedDict
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from collections import OrderedDict
 
-from .registry import register_model
+from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
-from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
-
+from .registry import register_model
 
 __all__ = ['DPN']
 
@@ -218,8 +218,8 @@ class DPN(nn.Module):
 
         # Using 1x1 conv for the FC layer to allow the extra pooling scheme
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.classifier = nn.Conv2d(
-            self.num_features * self.global_pool.feat_mult(), num_classes, kernel_size=1, bias=True)
+        num_features = self.num_features * self.global_pool.feat_mult()
+        self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
 
     def get_classifier(self):
         return self.classifier
@@ -228,10 +228,10 @@ class DPN(nn.Module):
         self.num_classes = num_classes
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         if num_classes:
-            self.classifier = nn.Conv2d(
-                self.num_features * self.global_pool.feat_mult(), num_classes, kernel_size=1, bias=True)
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
         else:
-            self.classifier = None
+            self.classifier = nn.Identity()
 
     def forward_features(self, x):
         return self.features(x)
timm/models/efficientnet.py
@@ -24,14 +24,12 @@ An implementation of EfficienNet that covers variety of related models with effi
 
 Hacked together by Ross Wightman
 """
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .efficientnet_builder import *
 from .feature_hooks import FeatureHooks
-from .registry import register_model
 from .helpers import load_pretrained, adapt_model_from_file
 from .layers import SelectAdaptivePool2d
-from timm.models.layers import create_conv2d
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-
+from .registry import register_model
 
 __all__ = ['EfficientNet']
 
@@ -373,8 +371,11 @@ class EfficientNet(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.classifier = nn.Linear(
-            self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
+        if num_classes:
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.classifier = nn.Linear(num_features, num_classes)
+        else:
+            self.classifier = nn.Identity()
 
     def forward_features(self, x):
         x = self.conv_stem(x)
@@ -785,13 +786,13 @@ def _gen_efficientnet_condconv(
     Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
     """
     arch_def = [
-      ['ds_r1_k3_s1_e1_c16_se0.25'],
-      ['ir_r2_k3_s2_e6_c24_se0.25'],
-      ['ir_r2_k5_s2_e6_c40_se0.25'],
-      ['ir_r3_k3_s2_e6_c80_se0.25'],
-      ['ir_r3_k5_s1_e6_c112_se0.25_cc4'],
-      ['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
-      ['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
+        ['ds_r1_k3_s1_e1_c16_se0.25'],
+        ['ir_r2_k3_s2_e6_c24_se0.25'],
+        ['ir_r2_k5_s2_e6_c40_se0.25'],
+        ['ir_r3_k3_s2_e6_c80_se0.25'],
+        ['ir_r3_k5_s1_e6_c112_se0.25_cc4'],
+        ['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
+        ['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
     ]
     # NOTE unlike official impl, this one uses `cc<x>` option where x is the base number of experts for each stage and
     # the expert_multiplier increases that on a per-model basis as with depth/channel multipliers
@@ -1187,6 +1188,7 @@ def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
         pretrained=pretrained, **kwargs)
     return model
 
+
 @register_model
 def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
     """ EfficientNet-CondConv-B1 w/ 8 Experts """
@@ -1242,8 +1244,6 @@ def efficientnet_lite4(pretrained=False, **kwargs):
     return model
 
 
-
-
 @register_model
 def efficientnet_b1_pruned(pretrained=False, **kwargs):
     """ EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
@@ -1275,8 +1275,6 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):
     return model
 
 
-
-
 @register_model
 def tf_efficientnet_b0(pretrained=False, **kwargs):
     """ EfficientNet-B0. Tensorflow compatible variant """
@@ -1619,6 +1617,7 @@ def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
         pretrained=pretrained, **kwargs)
     return model
 
+
 @register_model
 def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
     """ EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """
@@ -1764,4 +1763,3 @@ def tf_mixnet_l(pretrained=False, **kwargs):
     model = _gen_mixnet_m(
         'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
     return model
-
timm/models/gluon_resnet.py
@@ -3,17 +3,11 @@ This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-R
 and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py)
 by Ross Wightman
 """
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
 
-from .registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
 from .layers import SEModule
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-
+from .registry import register_model
 from .resnet import ResNet, Bottleneck, BasicBlock
 
 
@@ -202,8 +196,8 @@ def gluon_resnet50_v1e(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
     model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
                    stem_width=64, stem_type='deep', avg_down=True, **kwargs)
     model.default_cfg = default_cfg
-    #if pretrained:
-    #    load_pretrained(model, default_cfg, num_classes, in_chans)
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
     return model
 
 
timm/models/gluon_xception.py
@@ -6,15 +6,15 @@ Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xcep
 
 Hacked together by Ross Wightman
 """
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from collections import OrderedDict
-
-from .registry import register_model
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .registry import register_model
 
 __all__ = ['Xception65', 'Xception71']
 
@@ -47,7 +47,6 @@ default_cfgs = {
     }
 }
 
-
 """ PADDING NOTES
 The original PyTorch and Gluon impl of these models dutifully reproduced the
 aligned padding added to Tensorflow models for Deeplab. This padding was compensating
@@ -223,7 +222,7 @@ class Xception65(nn.Module):
             norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True, is_last=True)
 
         # Middle flow
-        self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
+        self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
             728, 728, num_reps=3, stride=1, dilation=middle_block_dilation,
             norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True))
             for i in range(4, 20)]))
@@ -333,7 +332,7 @@ class Xception71(nn.Module):
             exit_block_dilations = (2, 4)
         else:
             raise NotImplementedError
-
+
         # Entry flow
         self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
         self.bn1 = norm_layer(num_features=32, **norm_kwargs)
@@ -394,7 +393,11 @@ class Xception71(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
+        if num_classes:
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.fc = nn.Linear(num_features, num_classes)
+        else:
+            self.fc = nn.Identity()
 
     def forward_features(self, x):
         # Entry flow
@@ -465,4 +468,3 @@ def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
     return model
-
timm/models/inception_resnet_v2.py
@@ -6,10 +6,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .registry import register_model
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
-from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .registry import register_model
 
 __all__ = ['InceptionResnetV2']
 
@@ -296,8 +296,11 @@ class InceptionResnetV2(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        self.classif = nn.Linear(
-            self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
+        if num_classes:
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.classif = nn.Linear(num_features, num_classes)
+        else:
+            self.classif = nn.Identity()
 
     def forward_features(self, x):
         x = self.conv2d_1a(x)
timm/models/inception_v3.py
@@ -1,7 +1,8 @@
 from torchvision.models import Inception3
-from .registry import register_model
-from .helpers import load_pretrained
+
 from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import load_pretrained
+from .registry import register_model
 
 __all__ = []
 
timm/models/inception_v4.py
@@ -6,10 +6,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .registry import register_model
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
-from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .registry import register_model
 
 __all__ = ['InceptionV4']
 
@@ -280,8 +280,11 @@ class InceptionV4(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        self.last_linear = nn.Linear(
-            self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
+        if num_classes:
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.last_linear = nn.Linear(num_features, num_classes)
+        else:
+            self.last_linear = nn.Identity()
 
     def forward_features(self, x):
         return self.features(x)
@@ -303,6 +306,3 @@ def inception_v4(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
     return model
-
-
-
timm/models/mobilenetv3.py
@@ -8,13 +8,13 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
 Hacked together by Ross Wightman
 """
 
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .efficientnet_builder import *
-from .registry import register_model
+from .feature_hooks import FeatureHooks
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d, create_conv2d
 from .layers.activations import HardSwish, hard_sigmoid
-from .feature_hooks import FeatureHooks
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .registry import register_model
 
 __all__ = ['MobileNetV3']
 
@@ -76,7 +76,7 @@ class MobileNetV3(nn.Module):
                  channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
                  se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
         super(MobileNetV3, self).__init__()
-
+
         self.num_classes = num_classes
         self.num_features = num_features
         self.drop_rate = drop_rate
@@ -96,7 +96,7 @@ class MobileNetV3(nn.Module):
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
         self.feature_info = builder.features
         self._in_chs = builder.in_chs
-
+
         # Head + Pooling
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
@@ -120,8 +120,11 @@ class MobileNetV3(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        self.classifier = nn.Linear(
-            self.num_features * self.global_pool.feat_mult(), num_classes) if self.num_classes else None
+        if num_classes:
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.classifier = nn.Linear(num_features, num_classes)
+        else:
+            self.classifier = nn.Identity()
 
     def forward_features(self, x):
         x = self.conv_stem(x)
@@ -397,7 +400,6 @@ def mobilenetv3_small_075(pretrained=False, **kwargs):
 
 @register_model
 def mobilenetv3_small_100(pretrained=False, **kwargs):
-    print(kwargs)
     """ MobileNet V3 """
     model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
     return model
timm/models/nasnet.py
@@ -2,10 +2,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .registry import register_model
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
-
+from .registry import register_model
 
 __all__ = ['NASNetALarge']
 
@@ -187,17 +186,17 @@ class CellStem1(nn.Module):
         self.stem_size = stem_size
         self.conv_1x1 = nn.Sequential()
         self.conv_1x1.add_module('relu', nn.ReLU())
-        self.conv_1x1.add_module('conv', nn.Conv2d(2*self.num_channels, self.num_channels, 1, stride=1, bias=False))
+        self.conv_1x1.add_module('conv', nn.Conv2d(2 * self.num_channels, self.num_channels, 1, stride=1, bias=False))
         self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True))
 
         self.relu = nn.ReLU()
         self.path_1 = nn.Sequential()
         self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
-        self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels//2, 1, stride=1, bias=False))
+        self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
         self.path_2 = nn.ModuleList()
         self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
         self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
-        self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels//2, 1, stride=1, bias=False))
+        self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
 
         self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True)
 
@@ -507,50 +506,50 @@ class NASNetALarge(nn.Module):
         self.cell_stem_0 = CellStem0(self.stem_size, num_channels=channels // (channel_multiplier ** 2))
         self.cell_stem_1 = CellStem1(self.stem_size, num_channels=channels // channel_multiplier)
 
-        self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels//2,
-                                in_channels_right=2*channels, out_channels_right=channels)
-        self.cell_1 = NormalCell(in_channels_left=2*channels, out_channels_left=channels,
-                                 in_channels_right=6*channels, out_channels_right=channels)
-        self.cell_2 = NormalCell(in_channels_left=6*channels, out_channels_left=channels,
-                                 in_channels_right=6*channels, out_channels_right=channels)
-        self.cell_3 = NormalCell(in_channels_left=6*channels, out_channels_left=channels,
-                                 in_channels_right=6*channels, out_channels_right=channels)
-        self.cell_4 = NormalCell(in_channels_left=6*channels, out_channels_left=channels,
-                                 in_channels_right=6*channels, out_channels_right=channels)
-        self.cell_5 = NormalCell(in_channels_left=6*channels, out_channels_left=channels,
-                                 in_channels_right=6*channels, out_channels_right=channels)
+        self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels // 2,
+                                in_channels_right=2 * channels, out_channels_right=channels)
+        self.cell_1 = NormalCell(in_channels_left=2 * channels, out_channels_left=channels,
+                                 in_channels_right=6 * channels, out_channels_right=channels)
+        self.cell_2 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
+                                 in_channels_right=6 * channels, out_channels_right=channels)
+        self.cell_3 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
+                                 in_channels_right=6 * channels, out_channels_right=channels)
+        self.cell_4 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
+                                 in_channels_right=6 * channels, out_channels_right=channels)
+        self.cell_5 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
+                                 in_channels_right=6 * channels, out_channels_right=channels)
 
-        self.reduction_cell_0 = ReductionCell0(in_channels_left=6*channels, out_channels_left=2*channels,
-                                               in_channels_right=6*channels, out_channels_right=2*channels)
+        self.reduction_cell_0 = ReductionCell0(in_channels_left=6 * channels, out_channels_left=2 * channels,
+                                               in_channels_right=6 * channels, out_channels_right=2 * channels)
 
-        self.cell_6 = FirstCell(in_channels_left=6*channels, out_channels_left=channels,
-                                in_channels_right=8*channels, out_channels_right=2*channels)
-        self.cell_7 = NormalCell(in_channels_left=8*channels, out_channels_left=2*channels,
-                                 in_channels_right=12*channels, out_channels_right=2*channels)
-        self.cell_8 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels,
-                                 in_channels_right=12*channels, out_channels_right=2*channels)
-        self.cell_9 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels,
-                                 in_channels_right=12*channels, out_channels_right=2*channels)
-        self.cell_10 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels,
-                                  in_channels_right=12*channels, out_channels_right=2*channels)
-        self.cell_11 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels,
-                                  in_channels_right=12*channels, out_channels_right=2*channels)
+        self.cell_6 = FirstCell(in_channels_left=6 * channels, out_channels_left=channels,
+                                in_channels_right=8 * channels, out_channels_right=2 * channels)
+        self.cell_7 = NormalCell(in_channels_left=8 * channels, out_channels_left=2 * channels,
+                                 in_channels_right=12 * channels, out_channels_right=2 * channels)
+        self.cell_8 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
+                                 in_channels_right=12 * channels, out_channels_right=2 * channels)
+        self.cell_9 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
+                                 in_channels_right=12 * channels, out_channels_right=2 * channels)
+        self.cell_10 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
+                                  in_channels_right=12 * channels, out_channels_right=2 * channels)
+        self.cell_11 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
+                                  in_channels_right=12 * channels, out_channels_right=2 * channels)
 
-        self.reduction_cell_1 = ReductionCell1(in_channels_left=12*channels, out_channels_left=4*channels,
-                                               in_channels_right=12*channels, out_channels_right=4*channels)
+        self.reduction_cell_1 = ReductionCell1(in_channels_left=12 * channels, out_channels_left=4 * channels,
+                                               in_channels_right=12 * channels, out_channels_right=4 * channels)
 
-        self.cell_12 = FirstCell(in_channels_left=12*channels, out_channels_left=2*channels,
-                                 in_channels_right=16*channels, out_channels_right=4*channels)
-        self.cell_13 = NormalCell(in_channels_left=16*channels, out_channels_left=4*channels,
-                                  in_channels_right=24*channels, out_channels_right=4*channels)
-        self.cell_14 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels,
-                                  in_channels_right=24*channels, out_channels_right=4*channels)
-        self.cell_15 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels,
-                                  in_channels_right=24*channels, out_channels_right=4*channels)
-        self.cell_16 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels,
-                                  in_channels_right=24*channels, out_channels_right=4*channels)
-        self.cell_17 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels,
-                                  in_channels_right=24*channels, out_channels_right=4*channels)
+        self.cell_12 = FirstCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
+                                 in_channels_right=16 * channels, out_channels_right=4 * channels)
+        self.cell_13 = NormalCell(in_channels_left=16 * channels, out_channels_left=4 * channels,
+                                  in_channels_right=24 * channels, out_channels_right=4 * channels)
+        self.cell_14 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
+                                  in_channels_right=24 * channels, out_channels_right=4 * channels)
+        self.cell_15 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
+                                  in_channels_right=24 * channels, out_channels_right=4 * channels)
+        self.cell_16 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
+                                  in_channels_right=24 * channels, out_channels_right=4 * channels)
+        self.cell_17 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
+                                  in_channels_right=24 * channels, out_channels_right=4 * channels)
 
         self.relu = nn.ReLU()
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -562,9 +561,11 @@ class NASNetALarge(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        del self.last_linear
-        self.last_linear = nn.Linear(
-            self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
+        if num_classes:
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.last_linear = nn.Linear(num_features, num_classes)
+        else:
+            self.last_linear = nn.Identity()
 
     def forward_features(self, x):
         x_conv0 = self.conv0(x)
timm/models/pnasnet.py
@@ -6,15 +6,16 @@
 
 """
 from __future__ import print_function, division, absolute_import
 
+from collections import OrderedDict
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .registry import register_model
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
+from .registry import register_model
 
 __all__ = ['PNASNet5Large']
 
@@ -349,11 +350,11 @@ class PNASNet5Large(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        del self.last_linear
        if num_classes:
-            self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.last_linear = nn.Linear(num_features, num_classes)
         else:
-            self.last_linear = None
+            self.last_linear = nn.Identity()
 
     def forward_features(self, x):
         x_conv_0 = self.conv_0(x)
timm/models/res2net.py
@@ -6,13 +6,11 @@ import math
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
-from .resnet import ResNet
-from .registry import register_model
-from .helpers import load_pretrained
 from .layers import SEModule
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import load_pretrained
+from .registry import register_model
+from .resnet import ResNet
 
 __all__ = []
 
@@ -105,7 +103,7 @@ class Bottle2neck(nn.Module):
             sp = bn(sp)
             sp = self.relu(sp)
             spo.append(sp)
-        if self.scale > 1 :
+        if self.scale > 1:
             spo.append(self.pool(spx[-1]) if self.is_first else spx[-1])
         out = torch.cat(spo, 1)
 
timm/models/resnet.py
@@ -10,10 +10,10 @@ import math
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained, adapt_model_from_file
 from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .registry import register_model
 
 __all__ = ['ResNet', 'BasicBlock', 'Bottleneck']  # model_registry will add each entrypoint fn to this
 
@@ -377,6 +377,7 @@ class ResNet(nn.Module):
     global_pool : str, default 'avg'
         Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
     """
+
     def __init__(self, block, layers, num_classes=1000, in_chans=3,
                  cardinality=1, base_width=64, stem_width=64, stem_type='',
                  block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
@@ -482,8 +483,11 @@ class ResNet(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        del self.fc
-        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
+        if num_classes:
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.fc = nn.Linear(num_features, num_classes)
+        else:
+            self.fc = nn.Identity()
 
     def forward_features(self, x):
         x = self.conv1(x)
timm/models/selecsls.py
@@ -9,16 +9,15 @@ https://arxiv.org/abs/1907.00837
 Based on ResNet implementation in https://github.com/rwightman/pytorch-image-models
 and SelecSLS Net implementation in https://github.com/mehtadushy/SelecSLS-Pytorch
 """
-import math
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .registry import register_model
 
 __all__ = ['SelecSLS']  # model_registry will add each entrypoint fn to this
 
@@ -134,11 +133,11 @@ class SelecSLS(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        del self.fc
         if num_classes:
-            self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.fc = nn.Linear(num_features, num_classes)
         else:
-            self.fc = None
+            self.fc = nn.Identity()
 
     def forward_features(self, x):
         x = self.stem(x)
timm/models/senet.py
@@ -8,16 +8,16 @@ Original model: https://github.com/hujie-frank/SENet
 ResNet code gently borrowed from
 https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
 """
-from collections import OrderedDict
 import math
+from collections import OrderedDict
 
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .registry import register_model
 
 __all__ = ['SENet']
 
@@ -369,11 +369,11 @@ class SENet(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         self.avg_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        del self.last_linear
         if num_classes:
-            self.last_linear = nn.Linear(self.num_features * self.avg_pool.feat_mult(), num_classes)
+            num_features = self.num_features * self.avg_pool.feat_mult()
+            self.last_linear = nn.Linear(num_features, num_classes)
         else:
-            self.last_linear = None
+            self.last_linear = nn.Identity()
 
     def forward_features(self, x):
         x = self.layer0(x)
timm/models/tresnet.py
@@ -5,14 +5,16 @@ https://arxiv.org/pdf/2003.13630.pdf
 Original model: https://github.com/mrT23/TResNet
 
 """
+from collections import OrderedDict
+from functools import partial
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from collections import OrderedDict
 
+from .helpers import load_pretrained
 from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d
 from .registry import register_model
-from .helpers import load_pretrained
 
 try:
     from inplace_abn import InPlaceABN
@@ -88,7 +90,7 @@ class FastSEModule(nn.Module):
 
 
 def IABN2Float(module: nn.Module) -> nn.Module:
-    "If `module` is IABN don't use half precision."
+    """If `module` is IABN don't use half precision."""
     if isinstance(module, InPlaceABN):
         module.float()
     for child in module.children():
@@ -277,8 +279,10 @@ class TResNet(nn.Module):
         self.num_classes = num_classes
         self.head = None
         if num_classes:
-            self.head = nn.Sequential(OrderedDict([
-                ('fc', nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes))]))
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.head = nn.Sequential(OrderedDict([('fc', nn.Linear(num_features, num_classes))]))
+        else:
+            self.head = nn.Sequential(OrderedDict([('fc', nn.Identity())]))
 
     def forward_features(self, x):
         return self.body(x)
timm/models/xception.py
@@ -21,15 +21,13 @@ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
 
 The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
 """
-import math
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .registry import register_model
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
+from .registry import register_model
 
 __all__ = ['Xception']
 
@@ -180,8 +178,11 @@ class Xception(nn.Module):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        del self.fc
-        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
+        if num_classes:
+            num_features = self.num_features * self.global_pool.feat_mult()
+            self.fc = nn.Linear(num_features, num_classes)
+        else:
+            self.fc = nn.Identity()
 
     def forward_features(self, x):
         x = self.conv1(x)
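
Background for the recurring `num_features = self.num_features * self.global_pool.feat_mult()` lines factored out above: SelectAdaptivePool2d's feat_mult() reports how much the pooling layer multiplies the channel count, e.g. 2 for 'catavgmax', which concatenates average- and max-pooled features. A sketch assuming the timm API of this period:

    import torch
    from timm.models.layers import SelectAdaptivePool2d

    pool = SelectAdaptivePool2d(pool_type='catavgmax')
    x = torch.randn(2, 512, 7, 7)
    print(pool(x).shape)     # torch.Size([2, 1024, 1, 1])
    print(pool.feat_mult())  # 2 -> classifier in_features = num_features * feat_mult()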