ResNetV2 pre-act and non-preact model, w/ BiT pretrained weights and support for ViT R50 model. Tweaks for in21k num_classes passing. More to do... tests failing.
parent
de6046e213
commit
231d04e91a

timm/models/__init__.py
@@ -16,6 +16,7 @@ from .regnet import *
 from .res2net import *
 from .resnest import *
 from .resnet import *
+from .resnetv2 import *
 from .rexnet import *
 from .selecsls import *
 from .senet import *
timm/models/factory.py
@@ -6,8 +6,6 @@ from .layers import set_layer_config
 def create_model(
         model_name,
         pretrained=False,
-        num_classes=1000,
-        in_chans=3,
         checkpoint_path='',
         scriptable=None,
         exportable=None,
@@ -18,8 +16,6 @@ def create_model(
     Args:
         model_name (str): name of model to instantiate
         pretrained (bool): load pretrained ImageNet-1k weights if true
-        num_classes (int): number of classes for final fully connected layer (default: 1000)
-        in_chans (int): number of input channels / colors (default: 3)
         checkpoint_path (str): path of checkpoint to load after model is initialized
         scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
         exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
@@ -30,7 +26,7 @@ def create_model(
         global_pool (str): global pool type (default: 'avg')
         **: other kwargs are model specific
     """
-    model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
+    model_args = dict(pretrained=pretrained)

     # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
     is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
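Note: with num_classes and in_chans dropped from the explicit signature, they now travel through **kwargs to the model entrypoint, so an in21k default_cfg (num_classes=21843) is no longer clobbered by a hard-coded 1000. A minimal sketch of the resulting call pattern, using model names added in this commit:

```python
from timm import create_model

# omit num_classes and the model's own default applies
# (21843 for the *_in21k variants added below)
model = create_model('resnetv2_50x1_bitm_in21k', pretrained=False)

# pass it explicitly only to override, e.g. for fine-tuning
model_ft = create_model('resnetv2_50x1_bitm', pretrained=False, num_classes=10)
```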
timm/models/helpers.py
@@ -11,7 +11,7 @@ from typing import Callable

 import torch
 import torch.nn as nn
-import torch.utils.model_zoo as model_zoo
+from torch.hub import get_dir, load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX

 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .layers import Conv2dSame, Linear
@@ -88,15 +88,70 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None,
         raise FileNotFoundError()


-def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True):
+def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False):
+    r"""Loads a custom (read non .pth) weight file
+
+    Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
+    a passed in custom load fn, or the `load_pretrained` model member fn.
+
+    If the object is already present in `model_dir`, it's deserialized and returned.
+    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
+    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
+
+    Args:
+        model: The instantiated model to load weights into
+        cfg (dict): Default pretrained model cfg
+        load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
+            'load_pretrained' on the model will be called if it exists
+        progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
+        check_hash (bool, optional): If True, the filename part of the URL should follow the naming convention
+            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
+            digits of the SHA256 hash of the contents of the file. The hash is used to
+            ensure unique names and to verify the contents of the file. Default: False
+    """
     if cfg is None:
         cfg = getattr(model, 'default_cfg')
     if cfg is None or 'url' not in cfg or not cfg['url']:
-        _logger.warning("Pretrained model URL is invalid, using random initialization.")
+        _logger.warning("Pretrained model URL does not exist, using random initialization.")
         return
+    url = cfg['url']
+
+    # Issue warning to move data if old env is set
+    if os.getenv('TORCH_MODEL_ZOO'):
+        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
+
+    hub_dir = get_dir()
+    model_dir = os.path.join(hub_dir, 'checkpoints')
+
+    os.makedirs(model_dir, exist_ok=True)
+
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    cached_file = os.path.join(model_dir, filename)
+    if not os.path.exists(cached_file):
+        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
+        hash_prefix = None
+        if check_hash:
+            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
+            hash_prefix = r.group(1) if r else None
+        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+
+    if load_fn is not None:
+        load_fn(model, cached_file)
+    elif hasattr(model, 'load_pretrained'):
+        model.load_pretrained(cached_file)
+    else:
+        _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
+
+
+def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
+    if cfg is None:
+        cfg = getattr(model, 'default_cfg')
+    if cfg is None or 'url' not in cfg or not cfg['url']:
+        _logger.warning("Pretrained model URL does not exist, using random initialization.")
+        return

-    state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
+    state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
     if filter_fn is not None:
         state_dict = filter_fn(state_dict)
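Note: a minimal sketch of driving the new helper, assuming the ResNetV2 model and cfgs added later in this commit; that model carries a .npz URL in its default_cfg and implements its own load_pretrained, so no load_fn is needed:

```python
from timm.models.resnetv2 import ResNetV2, default_cfgs
from timm.models.helpers import load_custom_pretrained

model = ResNetV2(layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed')
model.default_cfg = default_cfgs['resnetv2_50x1_bitm']  # cfg pointing at a .npz

# downloads into <hub_dir>/checkpoints, then calls model.load_pretrained(cached_file);
# pass load_fn=... instead for models without a load_pretrained method
load_custom_pretrained(model)
```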
@@ -269,6 +324,7 @@ def build_model_with_cfg(
         feature_cfg: dict = None,
         pretrained_strict: bool = True,
         pretrained_filter_fn: Callable = None,
+        pretrained_custom_load: bool = False,
         **kwargs):
     pruned = kwargs.pop('pruned', False)
     features = False
@@ -289,10 +345,13 @@ def build_model_with_cfg(
     # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
     if pretrained:
-        load_pretrained(
-            model,
-            num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
-            filter_fn=pretrained_filter_fn, strict=pretrained_strict)
+        if pretrained_custom_load:
+            load_custom_pretrained(model)
+        else:
+            load_pretrained(
+                model,
+                num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
+                filter_fn=pretrained_filter_fn, strict=pretrained_strict)

     if features:
         feature_cls = FeatureListNet
timm/models/layers/__init__.py
@@ -7,7 +7,7 @@ from .classifier import ClassifierHead, create_classifier
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
     set_layer_config
-from .conv2d_same import Conv2dSame
+from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
 from .create_attn import create_attn
@@ -20,8 +20,8 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
 from .inplace_abn import InplaceAbn
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
-from .norm_act import BatchNormAct2d
-from .padding import get_padding
+from .norm_act import BatchNormAct2d, GroupNormAct
+from .padding import get_padding, get_same_padding, pad_same
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .se import SEModule
 from .selective_kernel import SelectiveKernelConv
timm/models/layers/classifier.py
@@ -9,31 +9,43 @@ from .adaptive_avgmax_pool import SelectAdaptivePool2d
 from .linear import Linear


-def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
-    flatten = not use_conv  # flatten when we use a Linear layer after pooling
+def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
+    flatten_in_pool = not use_conv  # flatten when we use a Linear layer after pooling
     if not pool_type:
         assert num_classes == 0 or use_conv,\
             'Pooling can only be disabled if classifier is also removed or conv classifier is used'
-        flatten = False  # disable flattening if pooling is pass-through (no pooling)
-    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten)
+        flatten_in_pool = False  # disable flattening if pooling is pass-through (no pooling)
+    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
     num_pooled_features = num_features * global_pool.feat_mult()
+    return global_pool, num_pooled_features
+
+
+def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False):
     if num_classes <= 0:
         fc = nn.Identity()  # pass-through (no classifier)
     elif use_conv:
-        fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
+        fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
     else:
         # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
-        fc = Linear(num_pooled_features, num_classes, bias=True)
+        fc = Linear(num_features, num_classes, bias=True)
+    return fc
+
+
+def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
+    global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
+    fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
+    return global_pool, fc


 class ClassifierHead(nn.Module):
     """Classifier head w/ configurable global pooling and dropout."""

-    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
+    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
         super(ClassifierHead, self).__init__()
         self.drop_rate = drop_rate
-        self.global_pool, self.fc = create_classifier(in_chs, num_classes, pool_type=pool_type)
+        self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
+        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
+        self.flatten_after_fc = use_conv and pool_type

     def forward(self, x):
         x = self.global_pool(x)
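Note: the split lets ClassifierHead size a conv classifier from the pooled feature count, which the new ResNetV2 uses (use_conv=True) so BiT's conv2d head weights load directly. A quick sketch of both flavors, using the module's own names:

```python
import torch
from timm.models.layers.classifier import ClassifierHead, create_classifier

# conv classifier: stays NCHW through the head
head = ClassifierHead(in_chs=2048, num_classes=1000, pool_type='avg', use_conv=True)
out = head(torch.randn(2, 2048, 7, 7))
# out should be (2, 1000, 1, 1) here; ResNetV2.forward flattens it afterwards

# classic Linear flavor via the unchanged public helper
global_pool, fc = create_classifier(2048, 1000, pool_type='avg')
```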
timm/models/layers/norm_act.py
@@ -68,8 +68,8 @@ class BatchNormAct2d(nn.BatchNorm2d):


 class GroupNormAct(nn.GroupNorm):
-
-    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True,
+    # NOTE num_channels and num_groups order flipped for easier layer swaps / binding of fixed args
+    def __init__(self, num_channels, num_groups, eps=1e-5, affine=True,
                  apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
         super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
         if isinstance(act_layer, str):
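Note: the flipped order is what makes the partial-binding pattern used in resnetv2.py work; timm constructs norm layers as norm_layer(num_features), so num_channels must be the first positional arg once num_groups is bound:

```python
from functools import partial
import torch.nn as nn
from timm.models.layers import GroupNormAct

# bind the group count once; the result is constructed like any batchnorm-style layer
norm_layer = partial(GroupNormAct, num_groups=32)
norm1 = norm_layer(256)    # GroupNormAct(num_channels=256, num_groups=32)
bn1 = nn.BatchNorm2d(256)  # same single-positional-arg call site
```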
timm/models/nasnet.py
@@ -403,7 +403,7 @@ class ReductionCell1(nn.Module):
 class NASNetALarge(nn.Module):
     """NASNetALarge (6 @ 4032) """

-    def __init__(self, num_classes=1000, in_chans=1, stem_size=96, channel_multiplier=2,
+    def __init__(self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2,
                  num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'):
         super(NASNetALarge, self).__init__()
         self.num_classes = num_classes
timm/models/resnetv2.py (new file)
@@ -0,0 +1,578 @@
"""Pre-Activation ResNet v2 with GroupNorm and Weight Standardization.

A PyTorch implementation of ResNetV2 adapted from the Google Big-Transfer (BiT) source code
at https://github.com/google-research/big_transfer to match timm interfaces. The BiT weights have
been included here as pretrained models from their original .NPZ checkpoints.

Additionally, supports non pre-activation bottleneck for use as a backbone for Vision Transformers (ViT) and
extra padding support to allow porting of official Hybrid ResNet pretrained weights from
https://github.com/google-research/vision_transformer

Thanks to the Google team for the above two repositories and associated papers.

Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
"""
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict  # pylint: disable=g-importing-member

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import get_padding, GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, conv2d_same


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 480, 480), 'pool_size': (7, 7),
        'crop_pct': 1.0, 'interpolation': 'bilinear',
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        **kwargs
    }

default_cfgs = {
    # pretrained on imagenet21k, finetuned on imagenet1k
    'resnetv2_50x1_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz'),
    'resnetv2_50x3_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz'),
    'resnetv2_101x1_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz'),
    'resnetv2_101x3_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz'),
    'resnetv2_152x2_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz'),
    'resnetv2_152x4_bitm': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz'),

    # trained on imagenet-21k
    'resnetv2_50x1_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz',
        num_classes=21843),
    'resnetv2_50x3_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz',
        num_classes=21843),
    'resnetv2_101x1_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz',
        num_classes=21843),
    'resnetv2_101x3_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz',
        num_classes=21843),
    'resnetv2_152x2_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz',
        num_classes=21843),
    'resnetv2_152x4_bitm_in21k': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
        num_classes=21843),

    # trained on imagenet-1k
    'resnetv2_50x1_bits': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-S-R50x1-ILSVRC2012.npz'),
    'resnetv2_50x3_bits': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-S-R50x3-ILSVRC2012.npz'),
    'resnetv2_101x1_bits': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-S-R101x1-ILSVRC2012.npz'),
    'resnetv2_101x3_bits': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-S-R101x3-ILSVRC2012.npz'),
    'resnetv2_152x2_bits': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-S-R152x2-ILSVRC2012.npz'),
    'resnetv2_152x4_bits': _cfg(
        url='https://storage.googleapis.com/bit_models/BiT-S-R152x4-ILSVRC2012.npz'),
}

def make_div(v, divisor=8):
    min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

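A few worked values, assuming the default divisor=8; make_div rounds to the nearest multiple and only bumps up a notch when rounding down would lose more than ~10% of the requested width:

```python
make_div(64)         # -> 64, already a multiple of 8
make_div(64 * 0.25)  # -> 16, bottleneck mid channels for out_chs=64
make_div(100)        # -> 104: int(100 + 4) // 8 * 8
make_div(50)         # -> 48, and 48 >= 0.9 * 50 so no +divisor correction
```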
class StdConv2d(nn.Conv2d):
    """Conv2d w/ Weight Standardization. Used for the BiT ResNetV2 models."""

    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
        padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, bias=bias, groups=groups)
        self.eps = eps

    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / (torch.sqrt(v) + self.eps)
        x = F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x


class StdConv2dSame(nn.Conv2d):
    """StdConv2d w/ TF compatible SAME padding. Used for ViT Hybrid model.
    """
    def __init__(
            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
        padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channel, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, bias=bias, groups=groups)
        self.eps = eps

    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / (torch.sqrt(v) + self.eps)
        x = conv2d_same(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x

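Note: both conv variants implement Weight Standardization, normalizing the kernel per output filter on every forward; this is what the BiT recipe pairs with GroupNorm in place of BatchNorm. A quick sanity sketch of the transform in isolation:

```python
import torch

w = torch.randn(32, 16, 3, 3)  # an OIHW kernel
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w_std = (w - m) / (torch.sqrt(v) + 1e-5)

print(w_std.mean(dim=[1, 2, 3]).abs().max())  # ~0 per output filter
print(w_std.std(dim=[1, 2, 3]).mean())        # ~1 per output filter
```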
def tf2th(conv_weights):
    """Possibly convert HWIO to OIHW."""
    if conv_weights.ndim == 4:
        conv_weights = conv_weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(conv_weights)

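Note: the BiT .npz checkpoints store conv kernels in TF's HWIO layout; tf2th transposes 4D arrays to PyTorch's OIHW and passes everything else (norm gammas/betas, biases) through as-is:

```python
import numpy as np

k_tf = np.zeros((3, 3, 16, 32), dtype=np.float32)  # HWIO: 3x3 kernel, 16 in, 32 out
print(tf2th(k_tf).shape)   # torch.Size([32, 16, 3, 3]) -- OIHW

gamma = np.ones(32, dtype=np.float32)
print(tf2th(gamma).shape)  # torch.Size([32]), non-4D arrays are not transposed
```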
class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    Follows the implementation of "Identity Mappings in Deep Residual Networks":
    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

    Except it puts the stride on 3x3 conv when available.
    """

    def __init__(
            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
        super().__init__()
        first_dilation = first_dilation or dilation
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        if proj_layer is not None:
            self.downsample = proj_layer(
                in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
                conv_layer=conv_layer, norm_layer=norm_layer)
        else:
            self.downsample = None

        self.norm1 = norm_layer(in_chs)
        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.norm2 = norm_layer(mid_chs)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
        self.norm3 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x):
        x_preact = self.norm1(x)

        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x_preact)

        # residual branch
        x = self.conv1(x_preact)
        x = self.conv2(self.norm2(x))
        x = self.conv3(self.norm3(x))
        x = self.drop_path(x)
        return x + shortcut

class Bottleneck(nn.Module):
    """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.
    """
    def __init__(
            self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
            act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
        super().__init__()
        first_dilation = first_dilation or dilation
        act_layer = act_layer or nn.ReLU
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        if proj_layer is not None:
            self.downsample = proj_layer(
                in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
                conv_layer=conv_layer, norm_layer=norm_layer)
        else:
            self.downsample = None

        self.conv1 = conv_layer(in_chs, mid_chs, 1)
        self.norm1 = norm_layer(mid_chs)
        self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
        self.norm2 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1)
        self.norm3 = norm_layer(out_chs, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.act3 = act_layer(inplace=True)

    def forward(self, x):
        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        # residual
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.act3(x + shortcut)
        return x

class DownsampleConv(nn.Module):
    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
            conv_layer=None, norm_layer=None):
        super(DownsampleConv, self).__init__()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
        self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        return self.norm(self.conv(x))


class DownsampleAvg(nn.Module):
    def __init__(
            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
            preact=True, conv_layer=None, norm_layer=None):
        """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
        super(DownsampleAvg, self).__init__()
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
        else:
            self.pool = nn.Identity()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
        self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        return self.norm(self.conv(self.pool(x)))

class ResNetStage(nn.Module):
    """ResNet Stage."""
    def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
                 avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
                 act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
        super(ResNetStage, self).__init__()
        first_dilation = 1 if dilation in (1, 2) else 2
        layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
        proj_layer = DownsampleAvg if avg_down else DownsampleConv
        prev_chs = in_chs
        self.blocks = nn.Sequential()
        for block_idx in range(depth):
            drop_path_rate = block_dpr[block_idx] if block_dpr else 0.
            stride = stride if block_idx == 0 else 1
            self.blocks.add_module(str(block_idx), block_fn(
                prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups,
                first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate,
                **layer_kwargs, **block_kwargs))
            prev_chs = out_chs
            first_dilation = dilation
            proj_layer = None

    def forward(self, x):
        x = self.blocks(x)
        return x

def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None):
    stem = OrderedDict()
    assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')

    # NOTE conv padding mode can be changed by overriding the conv_layer def
    if 'deep' in stem_type:
        # A 3 deep 3x3 conv stack as in ResNet V1D models
        mid_chs = out_chs // 2
        stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
        stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
        stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
    else:
        # The usual 7x7 stem conv
        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)

    if not preact:
        stem['norm'] = norm_layer(out_chs)

    if 'fixed' in stem_type:
        # 'fixed' SAME padding approximation that is used in BiT models
        stem['pad'] = nn.ConstantPad2d(1, 0)
        stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
    elif 'same' in stem_type:
        # full, input size based 'SAME' padding, used in ViT Hybrid model
        stem['pool'] = create_pool2d('max', kernel_size=3, stride=2, padding='same')
    else:
        # the usual PyTorch symmetric padding
        stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    return nn.Sequential(stem)

class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet models.
    """

    def __init__(self, layers, channels=(256, 512, 1024, 2048),
                 num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
                 width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
                 act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
                 drop_rate=0., drop_path_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        wf = width_factor

        self.feature_info = []
        stem_chs = make_div(stem_chs * wf)
        self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
        if not preact:
            self.feature_info.append(dict(num_chs=stem_chs, reduction=4, module='stem'))

        prev_chs = stem_chs
        curr_stride = 4
        dilation = 1
        block_fn = PreActBottleneck if preact else Bottleneck
        self.stages = nn.Sequential()
        for stage_idx, (d, c) in enumerate(zip(layers, channels)):
            out_chs = make_div(c * wf)
            stride = 1 if stage_idx == 0 else 2
            if curr_stride >= output_stride:
                dilation *= stride
                stride = 1
            if preact:
                self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}.norm1')]
            stage = ResNetStage(
                prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
                act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_fn=block_fn)
            prev_chs = out_chs
            curr_stride *= stride
            if not preact:
                self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
            self.stages.add_module(str(stage_idx), stage)

        self.num_features = prev_chs
        self.norm = norm_layer(self.num_features) if preact else nn.Identity()
        if preact:
            self.feature_info += [dict(num_chs=self.num_features, reduction=curr_stride, module='norm')]
        self.head = ClassifierHead(
            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

        for n, m in self.named_modules():
            if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)):
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.head = ClassifierHead(
            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        if not self.head.global_pool.is_identity():
            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
        return x

    def load_pretrained(self, checkpoint_path, prefix='resnet/'):
        import numpy as np
        weights = np.load(checkpoint_path)
        with torch.no_grad():
            self.stem.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
            self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
            self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
            self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
            self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
            for i, (sname, stage) in enumerate(self.stages.named_children()):
                for j, (bname, block) in enumerate(stage.blocks.named_children()):
                    convname = 'standardized_conv2d'
                    block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
                    block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel']))
                    block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel']))
                    block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel']))
                    block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma']))
                    block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma']))
                    block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma']))
                    block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta']))
                    block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta']))
                    block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta']))
                    if block.downsample is not None:
                        w = weights[f'{block_prefix}a/proj/{convname}/kernel']
                        block.downsample.conv.weight.copy_(tf2th(w))

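Note: a sketch of loading BiT weights straight from a local .npz checkpoint, assuming the file was fetched from one of the default_cfgs URLs above:

```python
import torch

model = ResNetV2(
    layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', num_classes=21843)
model.load_pretrained('BiT-M-R50x1.npz', prefix='resnet/')  # maps TF names onto modules

with torch.no_grad():
    out = model(torch.randn(1, 3, 480, 480))  # cfg input_size
print(out.shape)  # torch.Size([1, 21843])
```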
def _create_resnetv2(variant, pretrained=False, **kwargs):
    # FIXME feature map extraction is not setup properly for pre-activation mode right now
    return build_model_with_cfg(
        ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True,
        feature_cfg=dict(flatten_sequential=True), **kwargs)


@register_model
def resnetv2_50x1_bitm(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50x1_bitm', pretrained=pretrained,
        layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)


@register_model
def resnetv2_50x3_bitm(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50x3_bitm', pretrained=pretrained,
        layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)


@register_model
def resnetv2_101x1_bitm(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_101x1_bitm', pretrained=pretrained,
        layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)


@register_model
def resnetv2_101x3_bitm(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_101x3_bitm', pretrained=pretrained,
        layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)


@register_model
def resnetv2_152x2_bitm(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_152x2_bitm', pretrained=pretrained,
        layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)


@register_model
def resnetv2_152x4_bitm(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_152x4_bitm', pretrained=pretrained,
        layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)


@register_model
def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50x1_bitm_in21k', pretrained=pretrained,
        layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)


@register_model
def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50x3_bitm_in21k', pretrained=pretrained,
        layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)


@register_model
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_101x1_bitm_in21k', pretrained=pretrained,
        layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)


@register_model
def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_101x3_bitm_in21k', pretrained=pretrained,
        layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)


@register_model
def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_152x2_bitm_in21k', pretrained=pretrained,
        layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)


@register_model
def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_152x4_bitm_in21k', pretrained=pretrained,
        layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)


@register_model
def resnetv2_50x1_bits(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50x1_bits', pretrained=pretrained,
        layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)


@register_model
def resnetv2_50x3_bits(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50x3_bits', pretrained=pretrained,
        layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)


@register_model
def resnetv2_101x1_bits(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_101x1_bits', pretrained=pretrained,
        layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)


@register_model
def resnetv2_101x3_bits(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_101x3_bits', pretrained=pretrained,
        layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)


@register_model
def resnetv2_152x2_bits(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_152x2_bits', pretrained=pretrained,
        layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)


@register_model
def resnetv2_152x4_bits(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_152x4_bits', pretrained=pretrained,
        layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
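Note: end to end, these entrypoints pass pretrained_custom_load=True, so pretrained=True routes through load_custom_pretrained and the model's own .npz loader instead of a .pth state dict:

```python
import timm

# downloads the .npz from default_cfg['url'], then calls model.load_pretrained()
model = timm.create_model('resnetv2_101x1_bitm', pretrained=True)
```

Per the FIXME above, features_only extraction isn't wired up for the pre-activation models yet, so stick to classification use for now.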
timm/models/vision_transformer.py
@@ -23,11 +23,13 @@ Hacked together by / Copyright 2020 Ross Wightman
 import torch
 import torch.nn as nn
 from functools import partial
+from collections import OrderedDict

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
 from .layers import DropPath, to_2tuple, trunc_normal_
 from .resnet import resnet26d, resnet50d
+from .resnetv2 import ResNetV2, StdConv2dSame
 from .registry import register_model
@@ -43,14 +45,19 @@ def _cfg(url='', **kwargs):


 default_cfgs = {
-    # patch models
+    # patch models (my experiments)
     'vit_small_patch16_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
     ),
+
+    # patch models (weights ported from official JAX impl)
     'vit_base_patch16_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
     ),
+    'vit_base_patch32_224': _cfg(
+        url='',  # no official model weights for this combo, only for in21k
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_base_patch16_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
@@ -60,15 +67,38 @@ default_cfgs = {
     'vit_large_patch16_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+    'vit_large_patch32_224': _cfg(
+        url='',  # no official model weights for this combo, only for in21k
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_large_patch16_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
     'vit_large_patch32_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
-    'vit_huge_patch16_224': _cfg(),
-    'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)),
-    # hybrid models
+
+    # patch models, imagenet21k (weights ported from official JAX impl)
+    'vit_base_patch16_224_in21k': _cfg(
+        url='',
+        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+    'vit_base_patch32_224_in21k': _cfg(
+        url='',
+        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+    'vit_large_patch16_224_in21k': _cfg(
+        url='',
+        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+    'vit_large_patch32_224_in21k': _cfg(
+        url='',
+        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+    'vit_huge_patch14_224_in21k': _cfg(
+        url='',
+        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+
+    # hybrid models (weights ported from official JAX impl)
+    'vit_base_resnet50_384': _cfg(
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+
+    # hybrid models (my experiments)
     'vit_small_resnet26d_224': _cfg(),
     'vit_small_resnet50d_s3_224': _cfg(),
     'vit_base_resnet26d_224': _cfg(),
@@ -184,20 +214,26 @@ class HybridEmbed(nn.Module):
             training = backbone.training
             if training:
                 backbone.eval()
-            o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+            o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
+            if isinstance(o, (list, tuple)):
+                o = o[-1]  # last feature if backbone outputs list/tuple of features
             feature_size = o.shape[-2:]
             feature_dim = o.shape[1]
             backbone.train(training)
         else:
             feature_size = to_2tuple(feature_size)
-            feature_dim = self.backbone.feature_info.channels()[-1]
+            if hasattr(self.backbone, 'feature_info'):
+                feature_dim = self.backbone.feature_info.channels()[-1]
+            else:
+                feature_dim = self.backbone.num_features
         self.num_patches = feature_size[0] * feature_size[1]
-        self.proj = nn.Linear(feature_dim, embed_dim)
+        self.proj = nn.Conv2d(feature_dim, embed_dim, 1)

     def forward(self, x):
-        x = self.backbone(x)[-1]
-        x = x.flatten(2).transpose(1, 2)
-        x = self.proj(x)
+        x = self.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+        x = self.proj(x).flatten(2).transpose(1, 2)
         return x
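Note: moving proj to a 1x1 Conv2d means the projection runs on the backbone's NCHW feature map and flattening happens last. The shape flow for the hybrid model added below (384 input, backbone output stride 16), sketched as comments:

```python
# x:                            (B, 3, 384, 384)
# self.backbone(x):             (B, C, 24, 24)   # C probed at init via a zeros forward
# self.proj(x) (1x1 conv):      (B, embed_dim, 24, 24)
# .flatten(2).transpose(1, 2):  (B, 24 * 24, embed_dim) = (B, 576, embed_dim) tokens
```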
@@ -205,8 +241,8 @@ class VisionTransformer(nn.Module):
     """ Vision Transformer with support for patch or hybrid CNN input stage
     """
     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
-                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, representation_size=None,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
         super().__init__()
         self.num_classes = num_classes
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
@@ -231,9 +267,14 @@ class VisionTransformer(nn.Module):
             for i in range(depth)])
         self.norm = norm_layer(embed_dim)

-        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
-        #self.repr = nn.Linear(embed_dim, representation_size)
-        #self.repr_act = nn.Tanh()
+        # Representation layer
+        if representation_size:
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(embed_dim, representation_size)),
+                ('act', nn.Tanh())
+            ]))
+        else:
+            self.pre_logits = nn.Identity()

         # Classifier head
         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -279,6 +320,7 @@ class VisionTransformer(nn.Module):

     def forward(self, x):
         x = self.forward_features(x)
+        x = self.pre_logits(x)
         x = self.head(x)
         return x
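Note: the pre-logits block replaces the old commented-out sketch and mirrors the representation (Linear + Tanh) layer noted from the official JAX impl; with representation_size=None it collapses to an Identity, leaving the forward path unchanged. None of the entrypoints below set it yet, but the intended construction would presumably look like:

```python
# hypothetical: enable the pre-logits representation layer
model = VisionTransformer(
    img_size=224, patch_size=16, num_classes=21843, embed_dim=768, depth=12,
    num_heads=12, mlp_ratio=4, qkv_bias=True, representation_size=768)
# forward: features -> pre_logits (fc + tanh) -> head
```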
@@ -318,6 +360,17 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_base_patch32_224(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        img_size=224, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['vit_base_patch32_224']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+    return model
+
+
 @register_model
 def vit_base_patch16_384(pretrained=False, **kwargs):
     model = VisionTransformer(
@@ -351,6 +404,17 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_large_patch32_224(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        img_size=224, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['vit_large_patch32_224']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+    return model
+
+
 @register_model
 def vit_large_patch16_384(pretrained=False, **kwargs):
     model = VisionTransformer(
@@ -374,17 +438,72 @@ def vit_large_patch32_384(pretrained=False, **kwargs):


 @register_model
-def vit_huge_patch16_224(pretrained=False, **kwargs):
-    model = VisionTransformer(patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)
-    model.default_cfg = default_cfgs['vit_huge_patch16_224']
+def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        patch_size=16, num_classes=21843, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['vit_base_patch16_224_in21k']
     if pretrained:
-        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+        load_pretrained(
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
     return model


 @register_model
-def vit_huge_patch32_384(pretrained=False, **kwargs):
+def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
     model = VisionTransformer(
-        img_size=384, patch_size=32, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)
-    model.default_cfg = default_cfgs['vit_huge_patch32_384']
+        img_size=224, num_classes=21843, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['vit_base_patch32_224_in21k']
     if pretrained:
         load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model
+
+
+@register_model
+def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        patch_size=16, num_classes=21843, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['vit_large_patch16_224_in21k']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+    return model
+
+
+@register_model
+def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        img_size=224, num_classes=21843, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['vit_large_patch32_224_in21k']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+    return model
+
+
+@register_model
+def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        img_size=224, patch_size=14, num_classes=21843, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['vit_huge_patch14_224_in21k']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+    return model
+
+
+@register_model
+def vit_base_resnet50_384(pretrained=False, **kwargs):
+    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
+    backbone = ResNetV2(
+        layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='')
+    model = VisionTransformer(
+        img_size=384, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['vit_base_resnet50_384']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+    return model
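Note: the hybrid entrypoint ties the new pieces together: a headless, non-preact ResNetV2 with SAME-padded StdConv feeding HybridEmbed. A small usage sketch (random init; official weight porting is still pending per the commit message):

```python
import torch
from timm.models.vision_transformer import vit_base_resnet50_384

model = vit_base_resnet50_384()  # backbone stride 16 -> 24x24 = 576 patch tokens
logits = model(torch.randn(1, 3, 384, 384))
print(logits.shape)              # torch.Size([1, 1000])
```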

train.py
@@ -76,8 +76,8 @@ parser.add_argument('--resume', default='', type=str, metavar='PATH',
                     help='Resume full model and optimizer state from checkpoint (default: none)')
 parser.add_argument('--no-resume-opt', action='store_true', default=False,
                     help='prevent resume of optimizer state when resuming model')
-parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
-                    help='number of label classes (default: 1000)')
+parser.add_argument('--num-classes', type=int, default=None, metavar='N',
+                    help='number of label classes (Model default if None)')
 parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                     help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
 parser.add_argument('--img-size', type=int, default=None, metavar='N',
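Note: with the default now None, the trainer can defer to the model's own num_classes (e.g. 21843 for *_in21k variants) instead of forcing 1000. A hedged sketch of the intended resolution; the actual wiring lives elsewhere in train.py:

```python
# sketch only: pass num_classes through only when the user set it
model_kwargs = {}
if args.num_classes is not None:
    model_kwargs['num_classes'] = args.num_classes  # explicit override
model = create_model(args.model, pretrained=args.pretrained, **model_kwargs)
```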