Working on feature extraction, interfaces refined, a number of models working, some in progress.
parent 24e7535278
commit d23a2697d0
@@ -7,3 +7,4 @@ from .transforms_factory import create_transform
from .mixup import mixup_batch, FastCollateMixup
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
    rand_augment_transform, auto_augment_transform
from .real_labels import RealLabelsImagenet

36  timm/data/real_labels.py  Normal file
@@ -0,0 +1,36 @@
import os
import json
import numpy as np


class RealLabelsImagenet:

    def __init__(self, filenames, real_json='real.json', topk=(1, 5)):
        with open(real_json) as real_labels:
            real_labels = json.load(real_labels)
            real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
        self.real_labels = real_labels
        self.filenames = filenames
        assert len(self.filenames) == len(self.real_labels)
        self.topk = topk
        self.is_correct = {k: [] for k in topk}
        self.sample_idx = 0

    def add_result(self, output):
        maxk = max(self.topk)
        _, pred_batch = output.topk(maxk, 1, True, True)
        pred_batch = pred_batch.cpu().numpy()
        for pred in pred_batch:
            filename = self.filenames[self.sample_idx]
            filename = os.path.basename(filename)
            if self.real_labels[filename]:
                for k in self.topk:
                    self.is_correct[k].append(
                        any([p in self.real_labels[filename] for p in pred[:k]]))
            self.sample_idx += 1

    def get_accuracy(self, k=None):
        if k is None:
            return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk}
        else:
            return float(np.mean(self.is_correct[k])) * 100
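For context, a minimal sketch of how RealLabelsImagenet might be driven from a validation loop. `model`, `loader`, the `real.json` path, and `loader.dataset.filenames()` are placeholders for illustration, not part of this commit.

```python
# Hypothetical validation-loop usage of RealLabelsImagenet (names are placeholders).
import torch
from timm.data import RealLabelsImagenet

@torch.no_grad()
def validate_real(model, loader, device='cuda'):
    real_labels = RealLabelsImagenet(loader.dataset.filenames(), real_json='real.json')
    model.eval()
    for images, _ in loader:
        output = model(images.to(device))
        real_labels.add_result(output)    # tally top-k hits against the 'real' labels
    return real_labels.get_accuracy()     # dict of top-k accuracies in percent
```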
@@ -13,6 +13,7 @@ import torch.utils.checkpoint as cp
from torch.jit.annotations import List

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .features import FeatureNet
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d
from .registry import register_model
@@ -199,6 +200,9 @@ class DenseNet(nn.Module):
            ('norm0', norm_layer(num_init_features)),
            ('pool0', stem_pool),
        ]))
        self.feature_info = [
            dict(num_chs=num_init_features, reduction=2, module=f'features.norm{2 if deep_stem else 0}')]
        current_stride = 4

        # DenseBlocks
        num_features = num_init_features
@@ -212,21 +216,27 @@ class DenseNet(nn.Module):
                drop_rate=drop_rate,
                memory_efficient=memory_efficient
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            module_name = f'denseblock{(i + 1)}'
            self.features.add_module(module_name, block)
            num_features = num_features + num_layers * growth_rate
            transition_aa_layer = None if aa_stem_only else aa_layer
            if i != len(block_config) - 1:
                self.feature_info += [
                    dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)]
                current_stride *= 2
                trans = DenseTransition(
                    num_input_features=num_features, num_output_features=num_features // 2,
                    norm_layer=norm_layer, aa_layer=transition_aa_layer)
                self.features.add_module('transition%d' % (i + 1), trans)
                self.features.add_module(f'transition{i + 1}', trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', norm_layer(num_features))

        # Linear layer
        self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')]
        self.num_features = num_features

        # Linear layer
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)

@@ -279,16 +289,14 @@ def _filter_torchvision_pretrained(state_dict):


def _densenet(variant, growth_rate, block_config, pretrained, **kwargs):
    features = False
    out_indices = None
    if kwargs.pop('features_only', False):
        assert False, 'Not Implemented'  # TODO
        load_strict = False
        features = True
        kwargs.pop('num_classes', 0)
        model_class = DenseNet
    else:
        load_strict = True
        model_class = DenseNet
        out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
    default_cfg = default_cfgs[variant]
    model = model_class(growth_rate=growth_rate, block_config=block_config, **kwargs)
    model = DenseNet(growth_rate=growth_rate, block_config=block_config, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(
@@ -296,7 +304,9 @@ def _densenet(variant, growth_rate, block_config, pretrained, **kwargs):
            num_classes=kwargs.get('num_classes', 0),
            in_chans=kwargs.get('in_chans', 3),
            filter_fn=_filter_torchvision_pretrained,
            strict=load_strict)
            strict=not features)
    if features:
        model = FeatureNet(model, out_indices, flatten_sequential=True)
    return model

@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
|
||||
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
||||
from .feature_hooks import FeatureHooks
|
||||
from .features import FeatureInfo
|
||||
from .helpers import load_pretrained, adapt_model_from_file
|
||||
from .layers import SelectAdaptivePool2d, create_conv2d
|
||||
from .registry import register_model
|
||||
@ -438,42 +439,22 @@ class EfficientNetFeatures(nn.Module):
|
||||
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
|
||||
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
|
||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
||||
self._feature_info = builder.features # builder provides info about feature channels for each block
|
||||
self.feature_info = FeatureInfo(builder.features, out_indices)
|
||||
self._stage_to_feature_idx = {
|
||||
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
|
||||
v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices}
|
||||
self._in_chs = builder.in_chs
|
||||
|
||||
efficientnet_init_weights(self)
|
||||
if _DEBUG:
|
||||
for k, v in self._feature_info.items():
|
||||
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
|
||||
for fi, v in enumerate(self.feature_info):
|
||||
print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs']))
|
||||
|
||||
# Register feature extraction hooks with FeatureHooks helper
|
||||
self.feature_hooks = None
|
||||
if feature_location != 'bottleneck':
|
||||
hooks = [dict(
|
||||
name=self._feature_info[idx]['module'],
|
||||
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
|
||||
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
|
||||
def feature_channels(self, idx=None):
|
||||
""" Feature Channel Shortcut
|
||||
Returns feature channel count for each output index if idx == None. If idx is an integer, will
|
||||
return feature channel count for that feature block index (independent of out_indices setting).
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self._feature_info[idx]['num_chs']
|
||||
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
|
||||
|
||||
def feature_info(self, idx=None):
|
||||
""" Feature Channel Shortcut
|
||||
Returns feature channel count for each output index if idx == None. If idx is an integer, will
|
||||
return feature channel count for that feature block index (independent of out_indices setting).
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self._feature_info[idx]
|
||||
return [self._feature_info[i] for i in self.out_indices]
|
||||
|
||||
def forward(self, x) -> List[torch.Tensor]:
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
|
@ -225,7 +225,7 @@ class EfficientNetBuilder:
|
||||
|
||||
# state updated during build, consumed by model
|
||||
self.in_chs = None
|
||||
self.features = OrderedDict()
|
||||
self.features = []
|
||||
|
||||
def _round_channels(self, chs):
|
||||
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
|
||||
@ -291,7 +291,6 @@ class EfficientNetBuilder:
|
||||
total_block_idx = 0
|
||||
current_stride = 2
|
||||
current_dilation = 1
|
||||
feature_idx = 0
|
||||
stages = []
|
||||
# outer list of block_args defines the stacks ('stages' by some conventions)
|
||||
for stage_idx, stage_block_args in enumerate(model_block_args):
|
||||
@ -351,13 +350,15 @@ class EfficientNetBuilder:
|
||||
# stash feature module name and channel info for model feature extraction
|
||||
if extract_features:
|
||||
feature_info = block.feature_info(extract_features)
|
||||
if feature_info['module']:
|
||||
feature_info['module'] = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_info['module']
|
||||
module_name = f'blocks.{stage_idx}.{block_idx}'
|
||||
if 'module' in feature_info and feature_info['module']:
|
||||
feature_info['module'] = '.'.join([module_name, feature_info['module']])
|
||||
else:
|
||||
feature_info['module'] = module_name
|
||||
feature_info['stage_idx'] = stage_idx
|
||||
feature_info['block_idx'] = block_idx
|
||||
feature_info['reduction'] = current_stride
|
||||
self.features[feature_idx] = feature_info
|
||||
feature_idx += 1
|
||||
self.features.append(feature_info)
|
||||
|
||||
total_block_idx += 1 # incr global block idx (across all stacks)
|
||||
stages.append(nn.Sequential(*blocks))
|
||||
|
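As a rough illustration of what the builder now appends to `self.features`, one entry might look like the dict below. The values are made up, and any keys beyond 'module', 'stage_idx', 'block_idx' and 'reduction' depend on what `block.feature_info()` returns.

```python
# Illustrative only; not taken from a real model trace.
feature_info = {
    'module': 'blocks.1.2',   # prefixed with the stage/block path as shown above
    'hook_type': 'forward',   # consumed later by FeatureHooks
    'num_chs': 24,            # channel count reported by the block itself
    'stage_idx': 1,
    'block_idx': 2,
    'reduction': 4,           # current output stride when this feature is tapped
}
```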
@@ -1,3 +1,9 @@
""" PyTorch Feature Hook Helper

This class helps gather features from a network via hooks specified on the module name.

Hacked together by Ross Wightman
"""
import torch

from collections import defaultdict, OrderedDict
@@ -7,20 +13,21 @@ from typing import List

class FeatureHooks:

    def __init__(self, hooks, named_modules):
    def __init__(self, hooks, named_modules, output_as_dict=False):
        # setup feature hooks
        modules = {k: v for k, v in named_modules}
        for h in hooks:
            hook_name = h['name']
            hook_name = h['module']
            m = modules[hook_name]
            hook_fn = partial(self._collect_output_hook, hook_name)
            if h['type'] == 'forward_pre':
            if h['hook_type'] == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            elif h['type'] == 'forward':
            elif h['hook_type'] == 'forward':
                m.register_forward_hook(hook_fn)
            else:
                assert False, "Unsupported hook type"
        self._feature_outputs = defaultdict(OrderedDict)
        self.output_as_dict = output_as_dict

    def _collect_output_hook(self, name, *args):
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
@@ -29,6 +36,9 @@ class FeatureHooks:
        self._feature_outputs[x.device][name] = x

    def get_output(self, device) -> List[torch.tensor]:
        output = list(self._feature_outputs[device].values())
        if self.output_as_dict:
            output = self._feature_outputs[device]
        else:
            output = list(self._feature_outputs[device].values())
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        return output
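A minimal sketch of the hook flow, assuming a torchvision ResNet with hand-picked module names; the dict keys mirror the new 'module'/'hook_type' fields consumed above.

```python
# Hedged sketch: module names 'layer1'/'layer2' are illustrative, not from timm.
import torch
from torchvision.models import resnet18
from timm.models.feature_hooks import FeatureHooks

model = resnet18()
hooks = [
    dict(module='layer1', hook_type='forward'),
    dict(module='layer2', hook_type='forward'),
]
fh = FeatureHooks(hooks, model.named_modules())

x = torch.randn(1, 3, 224, 224)
model(x)                          # hooks collect outputs during the forward pass
feats = fh.get_output(x.device)   # list of tensors; storage is cleared after each read
print([f.shape for f in feats])
```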
251  timm/models/features.py  Normal file

@@ -0,0 +1,251 @@
""" PyTorch Feature Extraction Helpers

A collection of classes, functions, modules to help extract features from models
and provide a common interface for describing them.

Hacked together by Ross Wightman
"""
from collections import OrderedDict
from typing import Dict, List, Tuple, Any
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureInfo:

    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
        prev_reduction = 1
        for fi in feature_info:
            # sanity check the mandatory fields, there may be additional fields depending on the model
            assert 'num_chs' in fi and fi['num_chs'] > 0
            assert 'reduction' in fi and fi['reduction'] >= prev_reduction
            prev_reduction = fi['reduction']
            assert 'module' in fi
        self._out_indices = out_indices
        self._info = feature_info

    def from_other(self, out_indices: Tuple[int]):
        return FeatureInfo(deepcopy(self._info), out_indices)

    def channels(self, idx=None):
        """ feature channels accessor
        if idx == None, returns feature channel count at each output index
        if idx is an integer, return feature channel count for that feature module index
        """
        if isinstance(idx, int):
            return self._info[idx]['num_chs']
        return [self._info[i]['num_chs'] for i in self._out_indices]

    def reduction(self, idx=None):
        """ feature reduction (output stride) accessor
        if idx == None, returns feature reduction factor at each output index
        if idx is an integer, return feature reduction factor at that feature module index
        """
        if isinstance(idx, int):
            return self._info[idx]['reduction']
        return [self._info[i]['reduction'] for i in self._out_indices]

    def module_name(self, idx=None):
        """ feature module name accessor
        if idx == None, returns feature module name at each output index
        if idx is an integer, return feature module name at that feature module index
        """
        if isinstance(idx, int):
            return self._info[idx]['module']
        return [self._info[i]['module'] for i in self._out_indices]

    def get_by_key(self, idx=None, keys=None):
        """ return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
        """
        if isinstance(idx, int):
            return self._info[idx] if keys is None else {k: self._info[idx][k] for k in keys}
        if keys is None:
            return [self._info[i] for i in self._out_indices]
        else:
            return [{k: self._info[i][k] for k in keys} for i in self._out_indices]

    def __getitem__(self, item):
        return self._info[item]

    def __len__(self):
        return len(self._info)

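To make the accessor semantics concrete, a small sketch with hand-written entries in the same shape the models above register. The values roughly follow a DenseNet-121 layout and are purely for illustration.

```python
from timm.models.features import FeatureInfo

info = FeatureInfo([
    dict(num_chs=64, reduction=2, module='features.norm0'),
    dict(num_chs=256, reduction=4, module='features.denseblock1'),
    dict(num_chs=512, reduction=8, module='features.denseblock2'),
    dict(num_chs=1024, reduction=16, module='features.denseblock3'),
    dict(num_chs=1024, reduction=32, module='features.norm5'),
], out_indices=(1, 2, 3, 4))

print(info.channels())      # [256, 512, 1024, 1024] -> filtered by out_indices
print(info.reduction())     # [4, 8, 16, 32]
print(info.module_name(0))  # 'features.norm0'       -> integer idx bypasses out_indices
print(len(info))            # 5                      -> length of the full info list
```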
def _module_list(module, flatten_sequential=False):
|
||||
# a yield/iter would be better for this but wouldn't be compatible with torchscript
|
||||
ml = []
|
||||
for name, module in module.named_children():
|
||||
if flatten_sequential and isinstance(module, nn.Sequential):
|
||||
# first level of Sequential containers is flattened into containing model
|
||||
for child_name, child_module in module.named_children():
|
||||
ml.append(('_'.join([name, child_name]), child_module))
|
||||
else:
|
||||
ml.append((name, module))
|
||||
return ml
|
||||
|
||||
|
||||
def _check_return_layers(input_return_layers, modules):
|
||||
return_layers = {}
|
||||
for k, v in input_return_layers.items():
|
||||
ks = k.split('.')
|
||||
assert 0 < len(ks) <= 2
|
||||
return_layers['_'.join(ks)] = v
|
||||
return_set = set(return_layers.keys())
|
||||
sdiff = return_set - {name for name, _ in modules}
|
||||
if sdiff:
|
||||
raise ValueError(f'return_layers {sdiff} are not present in model')
|
||||
return return_layers, return_set
|
||||
|
||||
|
||||
class LayerGetterDict(nn.ModuleDict):
|
||||
"""
|
||||
Module wrapper that returns intermediate layers from a model as a dictionary
|
||||
|
||||
Originally based on IntermediateLayerGetter at
|
||||
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
|
||||
|
||||
It has a strong assumption that the modules have been registered into the model in the same
|
||||
order as they are used. This means that one should **not** reuse the same nn.Module twice
|
||||
in the forward if you want this to work.
|
||||
|
||||
Additionally, it is only able to query submodules that are directly assigned to the model
|
||||
class (`model.feature1`) or at most one Sequential container deep (`model.features.1`, so
|
||||
long as `features` is a sequential container assigned to the model).
|
||||
|
||||
All Sequential containers that are directly assigned to the original model will have their
|
||||
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
|
||||
|
||||
Arguments:
|
||||
model (nn.Module): model on which we will extract the features
|
||||
return_layers (Dict[name, new_name]): a dict containing the names
|
||||
of the modules for which the activations will be returned as
|
||||
the key of the dict, and the value of the dict is the name
|
||||
of the returned activation (which the user can specify).
|
||||
concat (bool): whether to concatenate intermediate features that are lists or tuples
|
||||
vs select element [0]
|
||||
flatten_sequential (bool): whether to flatten sequential modules assigned to model
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
self.return_layers, remaining = _check_return_layers(return_layers, modules)
|
||||
layers = OrderedDict()
|
||||
self.concat = concat
|
||||
for name, module in modules:
|
||||
layers[name] = module
|
||||
if name in remaining:
|
||||
remaining.remove(name)
|
||||
if not remaining:
|
||||
break
|
||||
super(LayerGetterDict, self).__init__(layers)
|
||||
|
||||
def forward(self, x) -> Dict[Any, torch.Tensor]:
|
||||
out = OrderedDict()
|
||||
for name, module in self.items():
|
||||
x = module(x)
|
||||
if name in self.return_layers:
|
||||
out_id = self.return_layers[name]
|
||||
if isinstance(x, (tuple, list)):
|
||||
# If model tap is a tuple or list, concat or select first element
|
||||
# FIXME this may need to be more generic / flexible for some nets
|
||||
out[out_id] = torch.cat(x, 1) if self.concat else x[0]
|
||||
else:
|
||||
out[out_id] = x
|
||||
return out
|
||||
|
||||
|
||||
class LayerGetterList(nn.Sequential):
|
||||
"""
|
||||
Module wrapper that returns intermediate layers from a model as a list
|
||||
|
||||
Originally based on IntermediateLayerGetter at
|
||||
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
|
||||
|
||||
It has a strong assumption that the modules have been registered into the model in the same
|
||||
order as they are used. This means that one should **not** reuse the same nn.Module twice
|
||||
in the forward if you want this to work.
|
||||
|
||||
Additionally, it is only able to query submodules that are directly assigned to the model
|
||||
class (`model.feature1`) or at most one Sequential container deep (`model.features.1`) so
|
||||
long as `features` is a sequential container assigned to the model and flatten_sequential=True.
|
||||
|
||||
All Sequential containers that are directly assigned to the original model will have their
|
||||
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
|
||||
|
||||
Arguments:
|
||||
model (nn.Module): model on which we will extract the features
|
||||
return_layers (Dict[name, new_name]): a dict containing the names
|
||||
of the modules for which the activations will be returned as
|
||||
the key of the dict, and the value of the dict is the name
|
||||
of the returned activation (which the user can specify).
|
||||
concat (bool): whether to concatenate intermediate features that are lists or tuples
|
||||
vs select element [0]
|
||||
flatten_sequential (bool): whether to flatten sequential modules assigned to model
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
|
||||
super(LayerGetterList, self).__init__()
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
self.return_layers, remaining = _check_return_layers(return_layers, modules)
|
||||
self.concat = concat
|
||||
for name, module in modules:
|
||||
self.add_module(name, module)
|
||||
if name in remaining:
|
||||
remaining.remove(name)
|
||||
if not remaining:
|
||||
break
|
||||
|
||||
def forward(self, x) -> List[torch.Tensor]:
|
||||
out = []
|
||||
for name, module in self.named_children():
|
||||
x = module(x)
|
||||
if name in self.return_layers:
|
||||
if isinstance(x, (tuple, list)):
|
||||
# If model tap is a tuple or list, concat or select first element
|
||||
# FIXME this may need to be more generic / flexible for some nets
|
||||
out.append(torch.cat(x, 1) if self.concat else x[0])
|
||||
else:
|
||||
out.append(x)
|
||||
return out
|
||||
|
||||
|
||||
def _resolve_feature_info(net, out_indices, feature_info=None):
    if feature_info is None:
        feature_info = getattr(net, 'feature_info')
    if isinstance(feature_info, FeatureInfo):
        return feature_info.from_other(out_indices)
    elif isinstance(feature_info, (list, tuple)):
        return FeatureInfo(net.feature_info, out_indices)
    else:
        assert False, "Provided feature_info is not valid"


class FeatureNet(nn.Module):
    """ FeatureNet

    Wrap a model and extract features as specified by the out indices; the network
    is partially re-built from contained modules using the LayerGetters.

    Please read the docstrings of the LayerGetter classes, they will not work on all models.
    """
    def __init__(
            self, net,
            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False,
            feature_info=None, feature_concat=False, flatten_sequential=False):
        super(FeatureNet, self).__init__()
        self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
        module_names = self.feature_info.module_name()
        return_layers = {}
        for i in range(len(out_indices)):
            return_layers[module_names[i]] = out_map[i] if out_map is not None else out_indices[i]
        lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential)
        self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args)

    def forward(self, x):
        output = self.body(x)
        return output
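A hedged end-to-end sketch of the wrapper, mirroring the features_only path in the _densenet and _inception_resnet_v2 builders; InceptionResnetV2 is used only because its feature_info is populated in this commit.

```python
import torch
from timm.models.inception_resnet_v2 import InceptionResnetV2
from timm.models.features import FeatureNet

net = InceptionResnetV2()
body = FeatureNet(net, out_indices=(0, 1, 2, 3, 4))

x = torch.randn(1, 3, 299, 299)
feats = body(x)   # list of 5 feature maps at strides 2, 4, 8, 16, 32
for f, c in zip(feats, body.feature_info.channels()):
    print(f.shape, c)
```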
@ -13,16 +13,16 @@ import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import load_pretrained
|
||||
from .layers import SelectAdaptivePool2d
|
||||
from .layers import SelectAdaptivePool2d, get_padding
|
||||
from .registry import register_model
|
||||
|
||||
__all__ = ['Xception65', 'Xception71']
|
||||
__all__ = ['Xception65']
|
||||
|
||||
default_cfgs = {
|
||||
'gluon_xception65': {
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth',
|
||||
'input_size': (3, 299, 299),
|
||||
'crop_pct': 0.875,
|
||||
'crop_pct': 0.903,
|
||||
'pool_size': (10, 10),
|
||||
'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN,
|
||||
@ -32,52 +32,13 @@ default_cfgs = {
|
||||
'classifier': 'fc'
|
||||
# The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
|
||||
},
|
||||
'gluon_xception71': {
|
||||
'url': '',
|
||||
'input_size': (3, 299, 299),
|
||||
'crop_pct': 0.875,
|
||||
'pool_size': (5, 5),
|
||||
'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN,
|
||||
'std': IMAGENET_DEFAULT_STD,
|
||||
'num_classes': 1000,
|
||||
'first_conv': 'conv1',
|
||||
'classifier': 'fc'
|
||||
# The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
|
||||
}
|
||||
}
|
||||
|
||||
""" PADDING NOTES
|
||||
The original PyTorch and Gluon impl of these models dutifully reproduced the
|
||||
aligned padding added to Tensorflow models for Deeplab. This padding was compensating
|
||||
for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to.
|
||||
|
||||
So, I'm phasing out the 'fixed_padding' ported from TF and replacing with normal
|
||||
PyTorch padding, some asserts to validate the equivalence for any scenario we'd
|
||||
care about before removing altogether.
|
||||
"""
|
||||
_USE_FIXED_PAD = False
|
||||
|
||||
|
||||
def _pytorch_padding(kernel_size, stride=1, dilation=1, **_):
|
||||
if _USE_FIXED_PAD:
|
||||
return 0 # FIXME remove once verified
|
||||
else:
|
||||
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
||||
|
||||
# FIXME remove once verified
|
||||
fp = _fixed_padding(kernel_size, dilation)
|
||||
assert all(padding == p for p in fp)
|
||||
|
||||
return padding
|
||||
|
||||
|
||||
def _fixed_padding(kernel_size, dilation):
|
||||
kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
|
||||
pad_total = kernel_size_effective - 1
|
||||
pad_beg = pad_total // 2
|
||||
pad_end = pad_total - pad_beg
|
||||
return [pad_beg, pad_end, pad_beg, pad_end]
|
||||
|
||||
|
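The padding notes above can be sanity-checked in isolation; this throwaway snippet re-implements both formulas locally and confirms they agree for the odd kernel sizes and dilations used here (stride 1 case shown). Helper names are local to the snippet, not part of the model code.

```python
# Standalone check of the PADDING NOTES claim.
def pytorch_padding(kernel_size, stride=1, dilation=1):
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2

def fixed_padding(kernel_size, dilation=1):
    effective = kernel_size + (kernel_size - 1) * (dilation - 1)
    pad_total = effective - 1
    pad_beg = pad_total // 2
    return [pad_beg, pad_total - pad_beg, pad_beg, pad_total - pad_beg]

for k in (3, 5, 7):
    for d in (1, 2, 4):
        sym = pytorch_padding(k, dilation=d)
        assert all(sym == p for p in fixed_padding(k, d)), (k, d)
print('symmetric padding matches TF-style fixed padding for odd kernels')
```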
||||
class SeparableConv2d(nn.Module):
|
||||
@ -88,24 +49,16 @@ class SeparableConv2d(nn.Module):
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation = dilation
|
||||
|
||||
padding = _fixed_padding(self.kernel_size, self.dilation)
|
||||
if _USE_FIXED_PAD and any(p > 0 for p in padding):
|
||||
self.fixed_padding = nn.ZeroPad2d(padding)
|
||||
else:
|
||||
self.fixed_padding = None
|
||||
|
||||
# depthwise convolution
|
||||
padding = get_padding(kernel_size, stride, dilation)
|
||||
self.conv_dw = nn.Conv2d(
|
||||
inplanes, inplanes, kernel_size, stride=stride,
|
||||
padding=_pytorch_padding(kernel_size, stride, dilation), dilation=dilation, groups=inplanes, bias=bias)
|
||||
padding=padding, dilation=dilation, groups=inplanes, bias=bias)
|
||||
self.bn = norm_layer(num_features=inplanes, **norm_kwargs)
|
||||
# pointwise convolution
|
||||
self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
if self.fixed_padding is not None:
|
||||
# FIXME remove once verified
|
||||
x = self.fixed_padding(x)
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn(x)
|
||||
x = self.conv_pw(x)
|
||||
@ -113,58 +66,37 @@ class SeparableConv2d(nn.Module):
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, inplanes, planes, num_reps, stride=1, dilation=1, norm_layer=None,
|
||||
norm_kwargs=None, start_with_relu=True, grow_first=True, is_last=False):
|
||||
def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True,
|
||||
norm_layer=None, norm_kwargs=None, ):
|
||||
super(Block, self).__init__()
|
||||
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
|
||||
if planes != inplanes or stride != 1:
|
||||
if isinstance(planes, (list, tuple)):
|
||||
assert len(planes) == 3
|
||||
else:
|
||||
planes = (planes,) * 3
|
||||
outplanes = planes[-1]
|
||||
|
||||
if outplanes != inplanes or stride != 1:
|
||||
self.skip = nn.Sequential()
|
||||
self.skip.add_module('conv1', nn.Conv2d(
|
||||
inplanes, planes, 1, stride=stride, bias=False)),
|
||||
self.skip.add_module('bn1', norm_layer(num_features=planes, **norm_kwargs))
|
||||
inplanes, outplanes, 1, stride=stride, bias=False)),
|
||||
self.skip.add_module('bn1', norm_layer(num_features=outplanes, **norm_kwargs))
|
||||
else:
|
||||
self.skip = None
|
||||
|
||||
rep = OrderedDict()
|
||||
l = 1
|
||||
filters = inplanes
|
||||
if grow_first:
|
||||
if start_with_relu:
|
||||
rep['act%d' % l] = nn.ReLU(inplace=False) # NOTE: silent failure if inplace=True here
|
||||
rep['conv%d' % l] = SeparableConv2d(
|
||||
inplanes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
|
||||
filters = planes
|
||||
l += 1
|
||||
for i in range(3):
|
||||
rep['act%d' % (i + 1)] = nn.ReLU(inplace=True)
|
||||
rep['conv%d' % (i + 1)] = SeparableConv2d(
|
||||
inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
rep['bn%d' % (i + 1)] = norm_layer(planes[i], **norm_kwargs)
|
||||
inplanes = planes[i]
|
||||
|
||||
for _ in range(num_reps - 1):
|
||||
if grow_first or start_with_relu:
|
||||
# FIXME being conservative with inplace here, think it's fine to leave True?
|
||||
rep['act%d' % l] = nn.ReLU(inplace=grow_first or not start_with_relu)
|
||||
rep['conv%d' % l] = SeparableConv2d(
|
||||
filters, filters, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
rep['bn%d' % l] = norm_layer(num_features=filters, **norm_kwargs)
|
||||
l += 1
|
||||
|
||||
if not grow_first:
|
||||
rep['act%d' % l] = nn.ReLU(inplace=True)
|
||||
rep['conv%d' % l] = SeparableConv2d(
|
||||
inplanes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
|
||||
l += 1
|
||||
|
||||
if stride != 1:
|
||||
rep['act%d' % l] = nn.ReLU(inplace=True)
|
||||
rep['conv%d' % l] = SeparableConv2d(
|
||||
planes, planes, 3, stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
|
||||
l += 1
|
||||
elif is_last:
|
||||
rep['act%d' % l] = nn.ReLU(inplace=True)
|
||||
rep['conv%d' % l] = SeparableConv2d(
|
||||
planes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
|
||||
l += 1
|
||||
if not start_with_relu:
|
||||
del rep['act1']
|
||||
else:
|
||||
rep['act1'] = nn.ReLU(inplace=False)
|
||||
self.rep = nn.Sequential(rep)
|
||||
|
||||
def forward(self, x):
|
||||
@ -176,7 +108,10 @@ class Block(nn.Module):
|
||||
|
||||
|
||||
class Xception65(nn.Module):
|
||||
"""Modified Aligned Xception
|
||||
"""Modified Aligned Xception.
|
||||
|
||||
NOTE: only the 65 layer version is included here, the 71 layer variant
|
||||
was not correct and had no pretrained weights
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
|
||||
@ -212,25 +147,21 @@ class Xception65(nn.Module):
|
||||
self.bn2 = norm_layer(num_features=64)
|
||||
|
||||
self.block1 = Block(
|
||||
64, 128, num_reps=2, stride=2,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=False)
|
||||
64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
self.block2 = Block(
|
||||
128, 256, num_reps=2, stride=2,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True)
|
||||
128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
self.block3 = Block(
|
||||
256, 728, num_reps=2, stride=entry_block3_stride,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True, is_last=True)
|
||||
256, 728, stride=entry_block3_stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
|
||||
# Middle flow
|
||||
self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
|
||||
728, 728, num_reps=3, stride=1, dilation=middle_block_dilation,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True))
|
||||
for i in range(4, 20)]))
|
||||
728, 728, stride=1, dilation=middle_block_dilation,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs)) for i in range(4, 20)]))
|
||||
|
||||
# Exit flow
|
||||
self.block20 = Block(
|
||||
728, 1024, num_reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0],
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=False, is_last=True)
|
||||
728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0],
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
|
||||
self.conv3 = SeparableConv2d(
|
||||
1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
|
||||
@ -305,147 +236,6 @@ class Xception65(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class Xception71(nn.Module):
|
||||
"""Modified Aligned Xception
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
|
||||
norm_kwargs=None, drop_rate=0., global_pool='avg'):
|
||||
super(Xception71, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
|
||||
if output_stride == 32:
|
||||
entry_block3_stride = 2
|
||||
exit_block20_stride = 2
|
||||
middle_block_dilation = 1
|
||||
exit_block_dilations = (1, 1)
|
||||
elif output_stride == 16:
|
||||
entry_block3_stride = 2
|
||||
exit_block20_stride = 1
|
||||
middle_block_dilation = 1
|
||||
exit_block_dilations = (1, 2)
|
||||
elif output_stride == 8:
|
||||
entry_block3_stride = 1
|
||||
exit_block20_stride = 1
|
||||
middle_block_dilation = 2
|
||||
exit_block_dilations = (2, 4)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# Entry flow
|
||||
self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = norm_layer(num_features=32, **norm_kwargs)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = norm_layer(num_features=64)
|
||||
|
||||
self.block1 = Block(
|
||||
64, 128, num_reps=2, stride=2, norm_layer=norm_layer,
|
||||
norm_kwargs=norm_kwargs, start_with_relu=False)
|
||||
self.block2 = nn.Sequential(*[
|
||||
Block(
|
||||
128, 256, num_reps=2, stride=1, norm_layer=norm_layer,
|
||||
norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True),
|
||||
Block(
|
||||
256, 256, num_reps=2, stride=2, norm_layer=norm_layer,
|
||||
norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True),
|
||||
Block(
|
||||
256, 728, num_reps=2, stride=2, norm_layer=norm_layer,
|
||||
norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True)])
|
||||
self.block3 = Block(
|
||||
728, 728, num_reps=2, stride=entry_block3_stride, norm_layer=norm_layer,
|
||||
norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True, is_last=True)
|
||||
|
||||
# Middle flow
|
||||
self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
|
||||
728, 728, num_reps=3, stride=1, dilation=middle_block_dilation,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True))
|
||||
for i in range(4, 20)]))
|
||||
|
||||
# Exit flow
|
||||
self.block20 = Block(
|
||||
728, 1024, num_reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0],
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=False, is_last=True)
|
||||
|
||||
self.conv3 = SeparableConv2d(
|
||||
1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
self.bn3 = norm_layer(num_features=1536, **norm_kwargs)
|
||||
|
||||
self.conv4 = SeparableConv2d(
|
||||
1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
|
||||
|
||||
self.num_features = 2048
|
||||
self.conv5 = SeparableConv2d(
|
||||
1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
|
||||
self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||
|
||||
def get_classifier(self):
|
||||
return self.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
if num_classes:
|
||||
num_features = self.num_features * self.global_pool.feat_mult()
|
||||
self.fc = nn.Linear(num_features, num_classes)
|
||||
else:
|
||||
self.fc = nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
# Entry flow
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
x = self.bn2(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.block1(x)
|
||||
# add relu here
|
||||
x = self.relu(x)
|
||||
# low_level_feat = x
|
||||
x = self.block2(x)
|
||||
# c2 = x
|
||||
x = self.block3(x)
|
||||
|
||||
# Middle flow
|
||||
x = self.mid(x)
|
||||
# c3 = x
|
||||
|
||||
# Exit flow
|
||||
x = self.block20(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv3(x)
|
||||
x = self.bn3(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.conv4(x)
|
||||
x = self.bn4(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.conv5(x)
|
||||
x = self.bn5(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.global_pool(x).flatten(1)
|
||||
if self.drop_rate:
|
||||
F.dropout(x, self.drop_rate, training=self.training)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" Modified Aligned Xception-65
|
||||
@ -456,15 +246,3 @@ def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" Modified Aligned Xception-71
|
||||
"""
|
||||
default_cfg = default_cfgs['gluon_xception71']
|
||||
model = Xception71(num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
@ -7,6 +7,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .features import FeatureNet
|
||||
from .helpers import load_pretrained
|
||||
from .layers import SelectAdaptivePool2d
|
||||
from .registry import register_model
|
||||
@ -231,9 +232,13 @@ class InceptionResnetV2(nn.Module):
|
||||
self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
|
||||
self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
|
||||
self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
|
||||
self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')]
|
||||
|
||||
self.maxpool_3a = nn.MaxPool2d(3, stride=2)
|
||||
self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
|
||||
self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
|
||||
self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')]
|
||||
|
||||
self.maxpool_5a = nn.MaxPool2d(3, stride=2)
|
||||
self.mixed_5b = Mixed_5b()
|
||||
self.repeat = nn.Sequential(
|
||||
@ -248,6 +253,8 @@ class InceptionResnetV2(nn.Module):
|
||||
Block35(scale=0.17),
|
||||
Block35(scale=0.17)
|
||||
)
|
||||
self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')]
|
||||
|
||||
self.mixed_6a = Mixed_6a()
|
||||
self.repeat_1 = nn.Sequential(
|
||||
Block17(scale=0.10),
|
||||
@ -271,6 +278,8 @@ class InceptionResnetV2(nn.Module):
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10)
|
||||
)
|
||||
self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')]
|
||||
|
||||
self.mixed_7a = Mixed_7a()
|
||||
self.repeat_2 = nn.Sequential(
|
||||
Block8(scale=0.20),
|
||||
@ -285,6 +294,8 @@ class InceptionResnetV2(nn.Module):
|
||||
)
|
||||
self.block8 = Block8(no_relu=True)
|
||||
self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
|
||||
self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')]
|
||||
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
# NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
|
||||
self.classif = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||
@ -328,30 +339,34 @@ class InceptionResnetV2(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@register_model
|
||||
def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r"""InceptionResnetV2 model architecture from the
|
||||
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
|
||||
"""
|
||||
default_cfg = default_cfgs['inception_resnet_v2']
|
||||
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
def _inception_resnet_v2(variant, pretrained=False, **kwargs):
|
||||
load_strict, features, out_indices = True, False, None
|
||||
if kwargs.pop('features_only', False):
|
||||
load_strict, features, out_indices = False, True, kwargs.pop('out_indices', (0, 1, 2, 3, 4))
|
||||
kwargs.pop('num_classes', 0)
|
||||
model = InceptionResnetV2(**kwargs)
|
||||
model.default_cfg = default_cfgs[variant]
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
|
||||
load_pretrained(
|
||||
model,
|
||||
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
|
||||
if features:
|
||||
model = FeatureNet(model, out_indices)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def ens_adv_inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
def inception_resnet_v2(pretrained=False, **kwargs):
|
||||
r"""InceptionResnetV2 model architecture from the
|
||||
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
|
||||
"""
|
||||
return _inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def ens_adv_inception_resnet_v2(pretrained=False, **kwargs):
|
||||
r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
|
||||
As per https://arxiv.org/abs/1705.07204 and
|
||||
https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
|
||||
"""
|
||||
default_cfg = default_cfgs['ens_adv_inception_resnet_v2']
|
||||
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
|
||||
return model
|
||||
return _inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs)
|
||||
|
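With the builder above, feature extraction can be requested through the regular factory; a hedged sketch, assuming create_model simply forwards features_only/out_indices as kwargs to the entrypoint.

```python
import torch
import timm

backbone = timm.create_model('inception_resnet_v2', features_only=True, out_indices=(1, 2, 3, 4))
feats = backbone(torch.randn(1, 3, 299, 299))
print([f.shape for f in feats])   # one map per requested index, strides 4/8/16/32
```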
@ -17,6 +17,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
|
||||
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
||||
from .feature_hooks import FeatureHooks
|
||||
from .features import FeatureInfo
|
||||
from .helpers import load_pretrained
|
||||
from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid
|
||||
from .registry import register_model
|
||||
@ -182,22 +183,20 @@ class MobileNetV3Features(nn.Module):
|
||||
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
|
||||
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
|
||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
||||
self._feature_info = builder.features # builder provides info about feature channels for each block
|
||||
self.feature_info = FeatureInfo(builder.features, out_indices)
|
||||
self._stage_to_feature_idx = {
|
||||
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
|
||||
v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices}
|
||||
self._in_chs = builder.in_chs
|
||||
|
||||
efficientnet_init_weights(self)
|
||||
if _DEBUG:
|
||||
for k, v in self._feature_info.items():
|
||||
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
|
||||
for fi, v in enumerate(self.feature_info):
|
||||
print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs']))
|
||||
|
||||
# Register feature extraction hooks with FeatureHooks helper
|
||||
self.feature_hooks = None
|
||||
if feature_location != 'bottleneck':
|
||||
hooks = [dict(
|
||||
name=self._feature_info[idx]['module'],
|
||||
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
|
||||
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
|
||||
def feature_channels(self, idx=None):
|
||||
@ -206,17 +205,8 @@ class MobileNetV3Features(nn.Module):
|
||||
return feature channel count for that feature block index (independent of out_indices setting).
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self._feature_info[idx]['num_chs']
|
||||
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
|
||||
|
||||
def feature_info(self, idx=None):
|
||||
""" Feature Channel Shortcut
|
||||
Returns feature channel count for each output index if idx == None. If idx is an integer, will
|
||||
return feature channel count for that feature block index (independent of out_indices setting).
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self._feature_info[idx]
|
||||
return [self._feature_info[i] for i in self.out_indices]
|
||||
return self.feature_info[idx]['num_chs']
|
||||
return [self.feature_info[i]['num_chs'] for i in self.out_indices]
|
||||
|
||||
def forward(self, x) -> List[torch.Tensor]:
|
||||
x = self.conv_stem(x)
|
||||
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .helpers import load_pretrained
|
||||
from .layers import SelectAdaptivePool2d
|
||||
from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
|
||||
from .registry import register_model
|
||||
|
||||
__all__ = ['NASNetALarge']
|
||||
@ -24,43 +24,31 @@ default_cfgs = {
|
||||
}
|
||||
|
||||
|
||||
class MaxPoolPad(nn.Module):
|
||||
class ActConvBn(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(MaxPoolPad, self).__init__()
|
||||
self.pad = nn.ZeroPad2d((1, 0, 1, 0))
|
||||
self.pool = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''):
|
||||
super(ActConvBn, self).__init__()
|
||||
self.act = nn.ReLU()
|
||||
self.conv = create_conv2d(
|
||||
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pad(x)
|
||||
x = self.pool(x)
|
||||
x = x[:, :, 1:, 1:]
|
||||
return x
|
||||
|
||||
|
||||
class AvgPoolPad(nn.Module):
|
||||
|
||||
def __init__(self, stride=2, padding=1):
|
||||
super(AvgPoolPad, self).__init__()
|
||||
self.pad = nn.ZeroPad2d((1, 0, 1, 0))
|
||||
self.pool = nn.AvgPool2d(3, stride=stride, padding=padding, count_include_pad=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pad(x)
|
||||
x = self.pool(x)
|
||||
x = x[:, :, 1:, 1:]
|
||||
x = self.act(x)
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class SeparableConv2d(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''):
|
||||
super(SeparableConv2d, self).__init__()
|
||||
self.depthwise_conv2d = nn.Conv2d(
|
||||
in_channels, in_channels, dw_kernel,
|
||||
stride=dw_stride, padding=dw_padding,
|
||||
bias=bias, groups=in_channels)
|
||||
self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias)
|
||||
self.depthwise_conv2d = create_conv2d(
|
||||
in_channels, in_channels, kernel_size=kernel_size,
|
||||
stride=stride, padding=padding, groups=in_channels)
|
||||
self.pointwise_conv2d = create_conv2d(
|
||||
in_channels, out_channels, kernel_size=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.depthwise_conv2d(x)
|
||||
@ -70,87 +58,48 @@ class SeparableConv2d(nn.Module):
|
||||
|
||||
class BranchSeparables(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_type='', stem_cell=False):
|
||||
super(BranchSeparables, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.separable_1 = SeparableConv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias)
|
||||
self.bn_sep_1 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True)
|
||||
self.relu1 = nn.ReLU()
|
||||
self.separable_2 = SeparableConv2d(in_channels, out_channels, kernel_size, 1, padding, bias=bias)
|
||||
self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
|
||||
middle_channels = out_channels if stem_cell else in_channels
|
||||
self.act_1 = nn.ReLU()
|
||||
self.separable_1 = SeparableConv2d(
|
||||
in_channels, middle_channels, kernel_size, stride=stride, padding=pad_type)
|
||||
self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, momentum=0.1)
|
||||
self.act_2 = nn.ReLU(inplace=True)
|
||||
self.separable_2 = SeparableConv2d(
|
||||
middle_channels, out_channels, kernel_size, stride=1, padding=pad_type)
|
||||
self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
x = self.act_1(x)
|
||||
x = self.separable_1(x)
|
||||
x = self.bn_sep_1(x)
|
||||
x = self.relu1(x)
|
||||
x = self.separable_2(x)
|
||||
x = self.bn_sep_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class BranchSeparablesStem(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
|
||||
super(BranchSeparablesStem, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.separable_1 = SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
|
||||
self.bn_sep_1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
|
||||
self.relu1 = nn.ReLU()
|
||||
self.separable_2 = SeparableConv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias)
|
||||
self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
x = self.separable_1(x)
|
||||
x = self.bn_sep_1(x)
|
||||
x = self.relu1(x)
|
||||
x = self.separable_2(x)
|
||||
x = self.bn_sep_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class BranchSeparablesReduction(BranchSeparables):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False):
|
||||
BranchSeparables.__init__(self, in_channels, out_channels, kernel_size, stride, padding, bias)
|
||||
self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
x = self.padding(x)
|
||||
x = self.separable_1(x)
|
||||
x = x[:, :, 1:, 1:].contiguous()
|
||||
x = self.bn_sep_1(x)
|
||||
x = self.relu1(x)
|
||||
x = self.act_2(x)
|
||||
x = self.separable_2(x)
|
||||
x = self.bn_sep_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class CellStem0(nn.Module):
|
||||
def __init__(self, stem_size, num_channels=42):
|
||||
def __init__(self, stem_size, num_channels=42, pad_type=''):
|
||||
super(CellStem0, self).__init__()
|
||||
self.num_channels = num_channels
|
||||
self.stem_size = stem_size
|
||||
self.conv_1x1 = nn.Sequential()
|
||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_1x1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels, 1, stride=1, bias=False))
|
||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True))
|
||||
self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1)
|
||||
|
||||
self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2)
|
||||
self.comb_iter_0_right = BranchSeparablesStem(self.stem_size, self.num_channels, 7, 2, 3, bias=False)
|
||||
self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
|
||||
self.comb_iter_0_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)
|
||||
|
||||
self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
self.comb_iter_1_right = BranchSeparablesStem(self.stem_size, self.num_channels, 7, 2, 3, bias=False)
|
||||
self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
|
||||
self.comb_iter_1_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)
|
||||
|
||||
self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
|
||||
self.comb_iter_2_right = BranchSeparablesStem(self.stem_size, self.num_channels, 5, 2, 2, bias=False)
|
||||
self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
|
||||
self.comb_iter_2_right = BranchSeparables(self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True)
|
||||
|
||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, 1, bias=False)
|
||||
self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type)
|
||||
self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.conv_1x1(x)
|
||||
@ -180,51 +129,46 @@ class CellStem0(nn.Module):
|
||||
|
||||
class CellStem1(nn.Module):
|
||||
|
||||
def __init__(self, stem_size, num_channels):
|
||||
def __init__(self, stem_size, num_channels, pad_type=''):
|
||||
super(CellStem1, self).__init__()
|
||||
self.num_channels = num_channels
|
||||
self.stem_size = stem_size
|
||||
self.conv_1x1 = nn.Sequential()
|
||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_1x1.add_module('conv', nn.Conv2d(2 * self.num_channels, self.num_channels, 1, stride=1, bias=False))
|
||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True))
|
||||
self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1)
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
self.act = nn.ReLU()
|
||||
self.path_1 = nn.Sequential()
|
||||
self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
||||
self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
|
||||
self.path_2 = nn.ModuleList()
|
||||
self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
|
||||
|
||||
self.path_2 = nn.Sequential()
|
||||
self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
|
||||
self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
||||
self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
|
||||
|
||||
self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True)
|
||||
self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1)
|
||||
|
||||
self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2, bias=False)
|
||||
self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, 3, bias=False)
|
||||
self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
|
||||
self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type)
|
||||
|
||||
self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, 3, bias=False)
|
||||
self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
|
||||
self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type)
|
||||
|
||||
self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
|
||||
self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2, bias=False)
|
||||
self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
|
||||
self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
|
||||
|
||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, 1, bias=False)
|
||||
self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type)
|
||||
self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
|
||||
|
||||
def forward(self, x_conv0, x_stem_0):
|
||||
x_left = self.conv_1x1(x_stem_0)
|
||||
|
||||
x_relu = self.relu(x_conv0)
|
||||
x_relu = self.act(x_conv0)
|
||||
# path 1
|
||||
x_path1 = self.path_1(x_relu)
|
||||
# path 2
|
||||
x_path2 = self.path_2.pad(x_relu)
|
||||
x_path2 = x_path2[:, :, 1:, 1:]
|
||||
x_path2 = self.path_2.avgpool(x_path2)
|
||||
x_path2 = self.path_2.conv(x_path2)
|
||||
x_path2 = self.path_2(x_relu)
|
||||
# final path
|
||||
x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
|
||||
|
||||
@ -253,49 +197,40 @@ class CellStem1(nn.Module):
|
||||
|
||||
class FirstCell(nn.Module):
|
||||
|
||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
|
||||
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
|
||||
super(FirstCell, self).__init__()
|
||||
self.conv_1x1 = nn.Sequential()
|
||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
|
||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
|
||||
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1)
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
self.act = nn.ReLU()
|
||||
self.path_1 = nn.Sequential()
|
||||
self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
||||
self.path_1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
||||
self.path_2 = nn.ModuleList()
|
||||
self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
|
||||
self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))
|
||||
|
||||
self.path_2 = nn.Sequential()
|
||||
self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
|
||||
self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
||||
self.path_2.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
||||
self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))
|
||||
|
||||
self.final_path_bn = nn.BatchNorm2d(out_channels_left * 2, eps=0.001, momentum=0.1, affine=True)
|
||||
self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1)
|
||||
|
||||
self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
|
||||
self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
||||
self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
|
||||
self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
|
||||
|
||||
self.comb_iter_1_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
|
||||
self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
||||
self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
|
||||
self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
|
||||
|
||||
self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||
|
||||
self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
||||
self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
|
||||
|
||||
def forward(self, x, x_prev):
|
||||
x_relu = self.relu(x_prev)
|
||||
# path 1
|
||||
x_relu = self.act(x_prev)
|
||||
x_path1 = self.path_1(x_relu)
|
||||
# path 2
|
||||
x_path2 = self.path_2.pad(x_relu)
|
||||
x_path2 = x_path2[:, :, 1:, 1:]
|
||||
x_path2 = self.path_2.avgpool(x_path2)
|
||||
x_path2 = self.path_2.conv(x_path2)
|
||||
# final path
|
||||
x_path2 = self.path_2(x_relu)
|
||||
x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
|
||||
|
||||
x_right = self.conv_1x1(x)
|
||||
|
||||
x_comb_iter_0_left = self.comb_iter_0_left(x_right)
|
||||
@ -322,30 +257,23 @@ class FirstCell(nn.Module):
|
||||
|
||||
class NormalCell(nn.Module):
|
||||
|
||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
|
||||
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
|
||||
super(NormalCell, self).__init__()
|
||||
self.conv_prev_1x1 = nn.Sequential()
|
||||
self.conv_prev_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
||||
self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
|
||||
self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
|
||||
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)
|
||||
|
||||
self.conv_1x1 = nn.Sequential()
|
||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
|
||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
|
||||
self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
|
||||
self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type)
|
||||
|
||||
self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
|
||||
self.comb_iter_0_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)
|
||||
self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type)
|
||||
self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type)
|
||||
|
||||
self.comb_iter_1_left = BranchSeparables(out_channels_left, out_channels_left, 5, 1, 2, bias=False)
|
||||
self.comb_iter_1_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)
|
||||
self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||
|
||||
self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||
|
||||
self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
||||
self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
|
||||
|
||||
def forward(self, x, x_prev):
|
||||
x_left = self.conv_prev_1x1(x_prev)
|
||||
@ -375,31 +303,24 @@ class NormalCell(nn.Module):
|
||||
|
||||
class ReductionCell0(nn.Module):
|
||||
|
||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
|
||||
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
|
||||
super(ReductionCell0, self).__init__()
|
||||
self.conv_prev_1x1 = nn.Sequential()
|
||||
self.conv_prev_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
||||
self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
|
||||
self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
|
||||
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)
|
||||
|
||||
self.conv_1x1 = nn.Sequential()
|
||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
|
||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
|
||||
self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
|
||||
self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
|
||||
|
||||
self.comb_iter_0_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
|
||||
self.comb_iter_0_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
|
||||
self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
|
||||
self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
|
||||
|
||||
self.comb_iter_1_left = MaxPoolPad()
|
||||
self.comb_iter_1_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
|
||||
self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
|
||||
self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
|
||||
|
||||
self.comb_iter_2_left = AvgPoolPad()
|
||||
self.comb_iter_2_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
|
||||
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||
|
||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
||||
self.comb_iter_4_right = MaxPoolPad()
|
||||
self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
|
||||
self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
|
||||
|
||||
def forward(self, x, x_prev):
|
||||
x_left = self.conv_prev_1x1(x_prev)
|
||||
@ -430,31 +351,24 @@ class ReductionCell0(nn.Module):
|
||||
|
||||
class ReductionCell1(nn.Module):
|
||||
|
||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
|
||||
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
|
||||
super(ReductionCell1, self).__init__()
|
||||
self.conv_prev_1x1 = nn.Sequential()
|
||||
self.conv_prev_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
||||
self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
|
||||
self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
|
||||
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)
|
||||
|
||||
self.conv_1x1 = nn.Sequential()
|
||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
|
||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
|
||||
self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
|
||||
self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
|
||||
|
||||
self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
|
||||
self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
|
||||
self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
|
||||
self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
|
||||
|
||||
self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
|
||||
self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
|
||||
self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
|
||||
|
||||
self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
|
||||
self.comb_iter_2_right = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
|
||||
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||
|
||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
||||
self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
|
||||
self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
|
||||
|
||||
def forward(self, x, x_prev):
|
||||
x_left = self.conv_prev_1x1(x_prev)
|
||||
@ -487,7 +401,7 @@ class NASNetALarge(nn.Module):
|
||||
"""NASNetALarge (6 @ 4032) """
|
||||
|
||||
def __init__(self, num_classes=1000, in_chans=1, stem_size=96, num_features=4032, channel_multiplier=2,
|
||||
drop_rate=0., global_pool='avg'):
|
||||
drop_rate=0., global_pool='avg', pad_type='same'):
|
||||
super(NASNetALarge, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.stem_size = stem_size
|
||||
@ -498,60 +412,79 @@ class NASNetALarge(nn.Module):
|
||||
channels = self.num_features // 24
|
||||
# 24 is default value for the architecture
|
||||
|
||||
self.conv0 = nn.Sequential()
|
||||
self.conv0.add_module('conv', nn.Conv2d(
|
||||
in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, bias=False))
|
||||
self.conv0.add_module('bn', nn.BatchNorm2d(self.stem_size, eps=0.001, momentum=0.1, affine=True))
|
||||
self.conv0 = ConvBnAct(
|
||||
in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2,
|
||||
norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)
|
||||
|
||||
self.cell_stem_0 = CellStem0(self.stem_size, num_channels=channels // (channel_multiplier ** 2))
|
||||
self.cell_stem_1 = CellStem1(self.stem_size, num_channels=channels // channel_multiplier)
|
||||
self.cell_stem_0 = CellStem0(
|
||||
self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type)
|
||||
self.cell_stem_1 = CellStem1(
|
||||
self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type)
|
||||
|
||||
self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels // 2,
|
||||
in_channels_right=2 * channels, out_channels_right=channels)
|
||||
self.cell_1 = NormalCell(in_channels_left=2 * channels, out_channels_left=channels,
|
||||
in_channels_right=6 * channels, out_channels_right=channels)
|
||||
self.cell_2 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
|
||||
in_channels_right=6 * channels, out_channels_right=channels)
|
||||
self.cell_3 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
|
||||
in_channels_right=6 * channels, out_channels_right=channels)
|
||||
self.cell_4 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
|
||||
in_channels_right=6 * channels, out_channels_right=channels)
|
||||
self.cell_5 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
|
||||
in_channels_right=6 * channels, out_channels_right=channels)
|
||||
self.cell_0 = FirstCell(
|
||||
in_chs_left=channels, out_chs_left=channels // 2,
|
||||
in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type)
|
||||
self.cell_1 = NormalCell(
|
||||
in_chs_left=2 * channels, out_chs_left=channels,
|
||||
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
|
||||
self.cell_2 = NormalCell(
|
||||
in_chs_left=6 * channels, out_chs_left=channels,
|
||||
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
|
||||
self.cell_3 = NormalCell(
|
||||
in_chs_left=6 * channels, out_chs_left=channels,
|
||||
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
|
||||
self.cell_4 = NormalCell(
|
||||
in_chs_left=6 * channels, out_chs_left=channels,
|
||||
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
|
||||
self.cell_5 = NormalCell(
|
||||
in_chs_left=6 * channels, out_chs_left=channels,
|
||||
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
|
||||
|
||||
self.reduction_cell_0 = ReductionCell0(in_channels_left=6 * channels, out_channels_left=2 * channels,
|
||||
in_channels_right=6 * channels, out_channels_right=2 * channels)
|
||||
self.reduction_cell_0 = ReductionCell0(
|
||||
in_chs_left=6 * channels, out_chs_left=2 * channels,
|
||||
in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||
self.cell_6 = FirstCell(
|
||||
in_chs_left=6 * channels, out_chs_left=channels,
|
||||
in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||
self.cell_7 = NormalCell(
|
||||
in_chs_left=8 * channels, out_chs_left=2 * channels,
|
||||
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||
self.cell_8 = NormalCell(
|
||||
in_chs_left=12 * channels, out_chs_left=2 * channels,
|
||||
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||
self.cell_9 = NormalCell(
|
||||
in_chs_left=12 * channels, out_chs_left=2 * channels,
|
||||
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||
self.cell_10 = NormalCell(
|
||||
in_chs_left=12 * channels, out_chs_left=2 * channels,
|
||||
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||
self.cell_11 = NormalCell(
|
||||
in_chs_left=12 * channels, out_chs_left=2 * channels,
|
||||
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||
|
||||
self.cell_6 = FirstCell(in_channels_left=6 * channels, out_channels_left=channels,
|
||||
in_channels_right=8 * channels, out_channels_right=2 * channels)
|
||||
self.cell_7 = NormalCell(in_channels_left=8 * channels, out_channels_left=2 * channels,
|
||||
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
||||
self.cell_8 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
||||
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
||||
self.cell_9 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
||||
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
||||
self.cell_10 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
||||
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
||||
self.cell_11 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
||||
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
||||
self.reduction_cell_1 = ReductionCell1(
|
||||
in_chs_left=12 * channels, out_chs_left=4 * channels,
|
||||
in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||
self.cell_12 = FirstCell(
|
||||
in_chs_left=12 * channels, out_chs_left=2 * channels,
|
||||
in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||
self.cell_13 = NormalCell(
|
||||
in_chs_left=16 * channels, out_chs_left=4 * channels,
|
||||
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||
self.cell_14 = NormalCell(
|
||||
in_chs_left=24 * channels, out_chs_left=4 * channels,
|
||||
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||
self.cell_15 = NormalCell(
|
||||
in_chs_left=24 * channels, out_chs_left=4 * channels,
|
||||
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||
self.cell_16 = NormalCell(
|
||||
in_chs_left=24 * channels, out_chs_left=4 * channels,
|
||||
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||
self.cell_17 = NormalCell(
|
||||
in_chs_left=24 * channels, out_chs_left=4 * channels,
|
||||
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||
|
||||
self.reduction_cell_1 = ReductionCell1(in_channels_left=12 * channels, out_channels_left=4 * channels,
|
||||
in_channels_right=12 * channels, out_channels_right=4 * channels)
|
||||
|
||||
self.cell_12 = FirstCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
||||
in_channels_right=16 * channels, out_channels_right=4 * channels)
|
||||
self.cell_13 = NormalCell(in_channels_left=16 * channels, out_channels_left=4 * channels,
|
||||
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
||||
self.cell_14 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
|
||||
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
||||
self.cell_15 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
|
||||
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
||||
self.cell_16 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
|
||||
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
||||
self.cell_17 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
|
||||
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
self.act = nn.ReLU(inplace=True)
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||
|
||||
@ -569,8 +502,11 @@ class NASNetALarge(nn.Module):
|
||||
|
||||
def forward_features(self, x):
|
||||
x_conv0 = self.conv0(x)
|
||||
#0
|
||||
|
||||
x_stem_0 = self.cell_stem_0(x_conv0)
|
||||
x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)
|
||||
#1
|
||||
|
||||
x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
|
||||
x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
|
||||
@ -578,25 +514,27 @@ class NASNetALarge(nn.Module):
|
||||
x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
|
||||
x_cell_4 = self.cell_4(x_cell_3, x_cell_2)
|
||||
x_cell_5 = self.cell_5(x_cell_4, x_cell_3)
|
||||
#2
|
||||
|
||||
x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4)
|
||||
|
||||
x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4)
|
||||
x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
|
||||
x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
|
||||
x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
|
||||
x_cell_10 = self.cell_10(x_cell_9, x_cell_8)
|
||||
x_cell_11 = self.cell_11(x_cell_10, x_cell_9)
|
||||
#3
|
||||
|
||||
x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10)
|
||||
|
||||
x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10)
|
||||
x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
|
||||
x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
|
||||
x_cell_15 = self.cell_15(x_cell_14, x_cell_13)
|
||||
x_cell_16 = self.cell_16(x_cell_15, x_cell_14)
|
||||
x_cell_17 = self.cell_17(x_cell_16, x_cell_15)
|
||||
x = self.relu(x_cell_17)
|
||||
x = self.act(x_cell_17)
|
||||
#4
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -14,7 +14,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .helpers import load_pretrained
|
||||
from .layers import SelectAdaptivePool2d
|
||||
from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
|
||||
from .registry import register_model
|
||||
|
||||
__all__ = ['PNASNet5Large']
|
||||
@ -35,34 +35,15 @@ default_cfgs = {
|
||||
}
|
||||
|
||||
|
||||
class MaxPool(nn.Module):
|
||||
|
||||
def __init__(self, kernel_size, stride=1, padding=1, zero_pad=False):
|
||||
super(MaxPool, self).__init__()
|
||||
self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None
|
||||
self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
if self.zero_pad is not None:
|
||||
x = self.zero_pad(x)
|
||||
x = self.pool(x)
|
||||
x = x[:, :, 1:, 1:]
|
||||
else:
|
||||
x = self.pool(x)
|
||||
return x
|
||||
|
||||
|
||||
class SeparableConv2d(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, dw_kernel_size, dw_stride,
|
||||
dw_padding):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''):
|
||||
super(SeparableConv2d, self).__init__()
|
||||
self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels,
|
||||
kernel_size=dw_kernel_size,
|
||||
stride=dw_stride, padding=dw_padding,
|
||||
groups=in_channels, bias=False)
|
||||
self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=1, bias=False)
|
||||
self.depthwise_conv2d = create_conv2d(
|
||||
in_channels, in_channels, kernel_size=kernel_size,
|
||||
stride=stride, padding=padding, groups=in_channels)
|
||||
self.pointwise_conv2d = create_conv2d(
|
||||
in_channels, out_channels, kernel_size=1, padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.depthwise_conv2d(x)
|
||||
@ -72,50 +53,39 @@ class SeparableConv2d(nn.Module):
|
||||
|
||||
class BranchSeparables(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||
stem_cell=False, zero_pad=False):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, stem_cell=False, padding=''):
|
||||
super(BranchSeparables, self).__init__()
|
||||
padding = kernel_size // 2
|
||||
middle_channels = out_channels if stem_cell else in_channels
|
||||
self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None
|
||||
self.relu_1 = nn.ReLU()
|
||||
self.separable_1 = SeparableConv2d(in_channels, middle_channels,
|
||||
kernel_size, dw_stride=stride,
|
||||
dw_padding=padding)
|
||||
self.act_1 = nn.ReLU()
|
||||
self.separable_1 = SeparableConv2d(
|
||||
in_channels, middle_channels, kernel_size, stride=stride, padding=padding)
|
||||
self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001)
|
||||
self.relu_2 = nn.ReLU()
|
||||
self.separable_2 = SeparableConv2d(middle_channels, out_channels,
|
||||
kernel_size, dw_stride=1,
|
||||
dw_padding=padding)
|
||||
self.act_2 = nn.ReLU()
|
||||
self.separable_2 = SeparableConv2d(
|
||||
middle_channels, out_channels, kernel_size, stride=1, padding=padding)
|
||||
self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu_1(x)
|
||||
if self.zero_pad is not None:
|
||||
x = self.zero_pad(x)
|
||||
x = self.separable_1(x)
|
||||
x = x[:, :, 1:, 1:].contiguous()
|
||||
else:
|
||||
x = self.separable_1(x)
|
||||
x = self.act_1(x)
|
||||
x = self.separable_1(x)
|
||||
x = self.bn_sep_1(x)
|
||||
x = self.relu_2(x)
|
||||
x = self.act_2(x)
|
||||
x = self.separable_2(x)
|
||||
x = self.bn_sep_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class ReluConvBn(nn.Module):
|
||||
class ActConvBn(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
|
||||
super(ReluConvBn, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
bias=False)
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''):
|
||||
super(ActConvBn, self).__init__()
|
||||
self.act = nn.ReLU()
|
||||
self.conv = create_conv2d(
|
||||
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
x = self.act(x)
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
@ -123,32 +93,24 @@ class ReluConvBn(nn.Module):
|
||||
|
||||
class FactorizedReduction(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
def __init__(self, in_channels, out_channels, padding=''):
|
||||
super(FactorizedReduction, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.act = nn.ReLU()
|
||||
self.path_1 = nn.Sequential(OrderedDict([
|
||||
('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
|
||||
('conv', nn.Conv2d(in_channels, out_channels // 2,
|
||||
kernel_size=1, bias=False)),
|
||||
('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)),
|
||||
]))
|
||||
self.path_2 = nn.Sequential(OrderedDict([
|
||||
('pad', nn.ZeroPad2d((0, 1, 0, 1))),
|
||||
('pad', nn.ZeroPad2d((-1, 1, -1, 1))), # shift
|
||||
('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
|
||||
('conv', nn.Conv2d(in_channels, out_channels // 2,
|
||||
kernel_size=1, bias=False)),
|
||||
('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)),
|
||||
]))
|
||||
self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.act(x)
|
||||
x_path1 = self.path_1(x)
|
||||
|
||||
x_path2 = self.path_2.pad(x)
|
||||
x_path2 = x_path2[:, :, 1:, 1:]
|
||||
x_path2 = self.path_2.avgpool(x_path2)
|
||||
x_path2 = self.path_2.conv(x_path2)
|
||||
|
||||
x_path2 = self.path_2(x)
|
||||
out = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
|
||||
return out
|
||||
|
||||
@ -179,49 +141,41 @@ class CellBase(nn.Module):
|
||||
x_comb_iter_4_right = x_right
|
||||
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
|
||||
|
||||
x_out = torch.cat(
|
||||
[x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
|
||||
x_out = torch.cat([x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
|
||||
return x_out
|
||||
|
||||
|
||||
class CellStem0(CellBase):
|
||||
|
||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right,
|
||||
out_channels_right):
|
||||
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, padding=''):
|
||||
super(CellStem0, self).__init__()
|
||||
self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right,
|
||||
kernel_size=1)
|
||||
self.comb_iter_0_left = BranchSeparables(in_channels_left,
|
||||
out_channels_left,
|
||||
kernel_size=5, stride=2,
|
||||
stem_cell=True)
|
||||
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=padding)
|
||||
|
||||
self.comb_iter_0_left = BranchSeparables(
|
||||
in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=padding)
|
||||
self.comb_iter_0_right = nn.Sequential(OrderedDict([
|
||||
('max_pool', MaxPool(3, stride=2)),
|
||||
('conv', nn.Conv2d(in_channels_left, out_channels_left,
|
||||
kernel_size=1, bias=False)),
|
||||
('bn', nn.BatchNorm2d(out_channels_left, eps=0.001)),
|
||||
('max_pool', create_pool2d('max', 3, stride=2, padding=padding)),
|
||||
('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=padding)),
|
||||
('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)),
|
||||
]))
|
||||
self.comb_iter_1_left = BranchSeparables(out_channels_right,
|
||||
out_channels_right,
|
||||
kernel_size=7, stride=2)
|
||||
self.comb_iter_1_right = MaxPool(3, stride=2)
|
||||
self.comb_iter_2_left = BranchSeparables(out_channels_right,
|
||||
out_channels_right,
|
||||
kernel_size=5, stride=2)
|
||||
self.comb_iter_2_right = BranchSeparables(out_channels_right,
|
||||
out_channels_right,
|
||||
kernel_size=3, stride=2)
|
||||
self.comb_iter_3_left = BranchSeparables(out_channels_right,
|
||||
out_channels_right,
|
||||
kernel_size=3)
|
||||
self.comb_iter_3_right = MaxPool(3, stride=2)
|
||||
self.comb_iter_4_left = BranchSeparables(in_channels_right,
|
||||
out_channels_right,
|
||||
kernel_size=3, stride=2,
|
||||
stem_cell=True)
|
||||
self.comb_iter_4_right = ReluConvBn(out_channels_right,
|
||||
out_channels_right,
|
||||
kernel_size=1, stride=2)
|
||||
|
||||
self.comb_iter_1_left = BranchSeparables(
|
||||
out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=padding)
|
||||
self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=padding)
|
||||
|
||||
self.comb_iter_2_left = BranchSeparables(
|
||||
out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=padding)
|
||||
self.comb_iter_2_right = BranchSeparables(
|
||||
out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=padding)
|
||||
|
||||
self.comb_iter_3_left = BranchSeparables(
|
||||
out_chs_right, out_chs_right, kernel_size=3, padding=padding)
|
||||
self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=padding)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparables(
|
||||
in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=padding)
|
||||
self.comb_iter_4_right = ActConvBn(
|
||||
out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=padding)
|
||||
|
||||
def forward(self, x_left):
|
||||
x_right = self.conv_1x1(x_left)
|
||||
@ -231,9 +185,8 @@ class CellStem0(CellBase):
|
||||
|
||||
class Cell(CellBase):
|
||||
|
||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right,
|
||||
out_channels_right, is_reduction=False, zero_pad=False,
|
||||
match_prev_layer_dimensions=False):
|
||||
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, padding='',
|
||||
is_reduction=False, match_prev_layer_dims=False):
|
||||
super(Cell, self).__init__()
|
||||
|
||||
# If `is_reduction` is set to `True` stride 2 is used for
|
||||
@ -244,45 +197,34 @@ class Cell(CellBase):
|
||||
# If `match_prev_layer_dimensions` is set to `True`
|
||||
# `FactorizedReduction` is used to reduce the spatial size
|
||||
# of the left input of a cell approximately by a factor of 2.
|
||||
self.match_prev_layer_dimensions = match_prev_layer_dimensions
|
||||
if match_prev_layer_dimensions:
|
||||
self.conv_prev_1x1 = FactorizedReduction(in_channels_left,
|
||||
out_channels_left)
|
||||
self.match_prev_layer_dimensions = match_prev_layer_dims
|
||||
if match_prev_layer_dims:
|
||||
self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=padding)
|
||||
else:
|
||||
self.conv_prev_1x1 = ReluConvBn(in_channels_left,
|
||||
out_channels_left, kernel_size=1)
|
||||
self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=padding)
|
||||
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=padding)
|
||||
|
||||
self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right,
|
||||
kernel_size=1)
|
||||
self.comb_iter_0_left = BranchSeparables(out_channels_left,
|
||||
out_channels_left,
|
||||
kernel_size=5, stride=stride,
|
||||
zero_pad=zero_pad)
|
||||
self.comb_iter_0_right = MaxPool(3, stride=stride, zero_pad=zero_pad)
|
||||
self.comb_iter_1_left = BranchSeparables(out_channels_right,
|
||||
out_channels_right,
|
||||
kernel_size=7, stride=stride,
|
||||
zero_pad=zero_pad)
|
||||
self.comb_iter_1_right = MaxPool(3, stride=stride, zero_pad=zero_pad)
|
||||
self.comb_iter_2_left = BranchSeparables(out_channels_right,
|
||||
out_channels_right,
|
||||
kernel_size=5, stride=stride,
|
||||
zero_pad=zero_pad)
|
||||
self.comb_iter_2_right = BranchSeparables(out_channels_right,
|
||||
out_channels_right,
|
||||
kernel_size=3, stride=stride,
|
||||
zero_pad=zero_pad)
|
||||
self.comb_iter_3_left = BranchSeparables(out_channels_right,
|
||||
out_channels_right,
|
||||
kernel_size=3)
|
||||
self.comb_iter_3_right = MaxPool(3, stride=stride, zero_pad=zero_pad)
|
||||
self.comb_iter_4_left = BranchSeparables(out_channels_left,
|
||||
out_channels_left,
|
||||
kernel_size=3, stride=stride,
|
||||
zero_pad=zero_pad)
|
||||
self.comb_iter_0_left = BranchSeparables(
|
||||
out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=padding)
|
||||
self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=padding)
|
||||
|
||||
self.comb_iter_1_left = BranchSeparables(
|
||||
out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=padding)
|
||||
self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=padding)
|
||||
|
||||
self.comb_iter_2_left = BranchSeparables(
|
||||
out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=padding)
|
||||
self.comb_iter_2_right = BranchSeparables(
|
||||
out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=padding)
|
||||
|
||||
self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3)
|
||||
self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=padding)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparables(
|
||||
out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=padding)
|
||||
if is_reduction:
|
||||
self.comb_iter_4_right = ReluConvBn(
|
||||
out_channels_right, out_channels_right, kernel_size=1, stride=stride)
|
||||
self.comb_iter_4_right = ActConvBn(
|
||||
out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=padding)
|
||||
else:
|
||||
self.comb_iter_4_right = None
|
||||
|
||||
@ -294,52 +236,53 @@ class Cell(CellBase):
|
||||
|
||||
|
||||
class PNASNet5Large(nn.Module):
|
||||
def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg'):
|
||||
def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg', padding=''):
|
||||
super(PNASNet5Large, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = 4320
|
||||
self.drop_rate = drop_rate
|
||||
|
||||
self.conv_0 = nn.Sequential(OrderedDict([
|
||||
('conv', nn.Conv2d(in_chans, 96, kernel_size=3, stride=2, bias=False)),
|
||||
('bn', nn.BatchNorm2d(96, eps=0.001))
|
||||
]))
|
||||
self.cell_stem_0 = CellStem0(in_channels_left=96, out_channels_left=54,
|
||||
in_channels_right=96,
|
||||
out_channels_right=54)
|
||||
self.cell_stem_1 = Cell(in_channels_left=96, out_channels_left=108,
|
||||
in_channels_right=270, out_channels_right=108,
|
||||
match_prev_layer_dimensions=True,
|
||||
is_reduction=True)
|
||||
self.cell_0 = Cell(in_channels_left=270, out_channels_left=216,
|
||||
in_channels_right=540, out_channels_right=216,
|
||||
match_prev_layer_dimensions=True)
|
||||
self.cell_1 = Cell(in_channels_left=540, out_channels_left=216,
|
||||
in_channels_right=1080, out_channels_right=216)
|
||||
self.cell_2 = Cell(in_channels_left=1080, out_channels_left=216,
|
||||
in_channels_right=1080, out_channels_right=216)
|
||||
self.cell_3 = Cell(in_channels_left=1080, out_channels_left=216,
|
||||
in_channels_right=1080, out_channels_right=216)
|
||||
self.cell_4 = Cell(in_channels_left=1080, out_channels_left=432,
|
||||
in_channels_right=1080, out_channels_right=432,
|
||||
is_reduction=True, zero_pad=True)
|
||||
self.cell_5 = Cell(in_channels_left=1080, out_channels_left=432,
|
||||
in_channels_right=2160, out_channels_right=432,
|
||||
match_prev_layer_dimensions=True)
|
||||
self.cell_6 = Cell(in_channels_left=2160, out_channels_left=432,
|
||||
in_channels_right=2160, out_channels_right=432)
|
||||
self.cell_7 = Cell(in_channels_left=2160, out_channels_left=432,
|
||||
in_channels_right=2160, out_channels_right=432)
|
||||
self.cell_8 = Cell(in_channels_left=2160, out_channels_left=864,
|
||||
in_channels_right=2160, out_channels_right=864,
|
||||
is_reduction=True)
|
||||
self.cell_9 = Cell(in_channels_left=2160, out_channels_left=864,
|
||||
in_channels_right=4320, out_channels_right=864,
|
||||
match_prev_layer_dimensions=True)
|
||||
self.cell_10 = Cell(in_channels_left=4320, out_channels_left=864,
|
||||
in_channels_right=4320, out_channels_right=864)
|
||||
self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864,
|
||||
in_channels_right=4320, out_channels_right=864)
|
||||
self.conv_0 = ConvBnAct(
|
||||
in_chans, 96, kernel_size=3, stride=2, padding=0,
|
||||
norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)
|
||||
|
||||
self.cell_stem_0 = CellStem0(
|
||||
in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, padding=padding)
|
||||
|
||||
self.cell_stem_1 = Cell(
|
||||
in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, padding=padding,
|
||||
match_prev_layer_dims=True, is_reduction=True)
|
||||
self.cell_0 = Cell(
|
||||
in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, padding=padding,
|
||||
match_prev_layer_dims=True)
|
||||
self.cell_1 = Cell(
|
||||
in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding)
|
||||
self.cell_2 = Cell(
|
||||
in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding)
|
||||
self.cell_3 = Cell(
|
||||
in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding)
|
||||
|
||||
self.cell_4 = Cell(
|
||||
in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, padding=padding,
|
||||
is_reduction=True)
|
||||
self.cell_5 = Cell(
|
||||
in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding,
|
||||
match_prev_layer_dims=True)
|
||||
self.cell_6 = Cell(
|
||||
in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding)
|
||||
self.cell_7 = Cell(
|
||||
in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding)
|
||||
|
||||
self.cell_8 = Cell(
|
||||
in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, padding=padding,
|
||||
is_reduction=True)
|
||||
self.cell_9 = Cell(
|
||||
in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding,
|
||||
match_prev_layer_dims=True)
|
||||
self.cell_10 = Cell(
|
||||
in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding)
|
||||
self.cell_11 = Cell(
|
||||
in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding)
|
||||
self.relu = nn.ReLU()
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||
@ -391,7 +334,7 @@ def pnasnet5large(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
<https://arxiv.org/abs/1712.00559>`_ paper.
|
||||
"""
|
||||
default_cfg = default_cfgs['pnasnet5large']
|
||||
model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans, padding='same', **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
|
@ -10,7 +10,7 @@ import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import load_pretrained
|
||||
from .registry import register_model
|
||||
from .resnet import ResNet
|
||||
from .resnet import _create_resnet_with_cfg
|
||||
|
||||
__all__ = []
|
||||
|
||||
@ -132,113 +132,83 @@ class Bottle2neck(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
def _create_res2net(variant, pretrained=False, **kwargs):
|
||||
default_cfg = default_cfgs[variant]
|
||||
return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def res2net50_26w_4s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a Res2Net-50_26w_4s model.
|
||||
def res2net50_26w_4s(pretrained=False, **kwargs):
|
||||
"""Constructs a Res2Net-50 26w4s model.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
default_cfg = default_cfgs['res2net50_26w_4s']
|
||||
res2net_block_args = dict(scale=4)
|
||||
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=26,
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
model_args = dict(
|
||||
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs)
|
||||
return _create_res2net('res2net50_26w_4s', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def res2net101_26w_4s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a Res2Net-50_26w_4s model.
|
||||
def res2net101_26w_4s(pretrained=False, **kwargs):
|
||||
"""Constructs a Res2Net-101 26w4s model.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
default_cfg = default_cfgs['res2net101_26w_4s']
|
||||
res2net_block_args = dict(scale=4)
|
||||
model = ResNet(Bottle2neck, [3, 4, 23, 3], base_width=26,
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
model_args = dict(
|
||||
block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs)
|
||||
return _create_res2net('res2net101_26w_4s', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def res2net50_26w_6s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a Res2Net-50_26w_4s model.
|
||||
def res2net50_26w_6s(pretrained=False, **kwargs):
|
||||
"""Constructs a Res2Net-50 26w6s model.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
default_cfg = default_cfgs['res2net50_26w_6s']
|
||||
res2net_block_args = dict(scale=6)
|
||||
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=26,
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
model_args = dict(
|
||||
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs)
|
||||
return _create_res2net('res2net50_26w_6s', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def res2net50_26w_8s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a Res2Net-50_26w_4s model.
|
||||
def res2net50_26w_8s(pretrained=False, **kwargs):
|
||||
"""Constructs a Res2Net-50 26w8s model.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
default_cfg = default_cfgs['res2net50_26w_8s']
|
||||
res2net_block_args = dict(scale=8)
|
||||
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=26,
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
model_args = dict(
|
||||
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs)
|
||||
return _create_res2net('res2net50_26w_8s', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def res2net50_48w_2s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a Res2Net-50_48w_2s model.
|
||||
def res2net50_48w_2s(pretrained=False, **kwargs):
|
||||
"""Constructs a Res2Net-50 48w2s model.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
default_cfg = default_cfgs['res2net50_48w_2s']
|
||||
res2net_block_args = dict(scale=2)
|
||||
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=48,
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
model_args = dict(
|
||||
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs)
|
||||
return _create_res2net('res2net50_26w_8s', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def res2net50_14w_8s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Constructs a Res2Net-50_14w_8s model.
|
||||
def res2net50_14w_8s(pretrained=False, **kwargs):
|
||||
"""Constructs a Res2Net-50 14w8s model.
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
default_cfg = default_cfgs['res2net50_14w_8s']
|
||||
res2net_block_args = dict(scale=8)
|
||||
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=14, num_classes=num_classes, in_chans=in_chans,
|
||||
block_args=res2net_block_args, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
model_args = dict(
|
||||
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs)
|
||||
return _create_res2net('res2net50_26w_8s', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def res2next50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
def res2next50(pretrained=False, **kwargs):
|
||||
"""Construct Res2NeXt-50 4s
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
default_cfg = default_cfgs['res2next50']
|
||||
res2net_block_args = dict(scale=4)
|
||||
model = ResNet(Bottle2neck, [3, 4, 6, 3], base_width=4, cardinality=8,
|
||||
num_classes=num_classes, in_chans=in_chans, block_args=res2net_block_args, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
model_args = dict(
|
||||
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs)
|
||||
return _create_res2net('res2next50', pretrained, **model_args)
|
||||
|
@ -6,18 +6,14 @@ Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang198
|
||||
|
||||
Modified for torchscript compat, and consistency with timm by Ross Wightman
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models.layers import DropBlock2d
|
||||
from .helpers import load_pretrained
|
||||
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
|
||||
from .layers.split_attn import SplitAttnConv2d
|
||||
from .registry import register_model
|
||||
from .resnet import ResNet
|
||||
from .resnet import _create_resnet_with_cfg
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
@ -143,125 +139,98 @@ class ResNestBottleneck(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
def _create_resnest(variant, pretrained=False, **kwargs):
|
||||
default_cfg = default_cfgs[variant]
|
||||
return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnest14d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
def resnest14d(pretrained=False, **kwargs):
|
||||
""" ResNeSt-14d model. Weights ported from GluonCV.
|
||||
"""
|
||||
default_cfg = default_cfgs['resnest14d']
|
||||
model = ResNet(
|
        ResNestBottleneck, [1, 1, 1, 1], num_classes=num_classes, in_chans=in_chans,
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[1, 1, 1, 1],
        stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs)


@register_model
def resnest26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest26d(pretrained=False, **kwargs):
    """ ResNeSt-26d model. Weights ported from GluonCV.
    """
    default_cfg = default_cfgs['resnest26d']
    model = ResNet(
        ResNestBottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[2, 2, 2, 2],
        stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs)


@register_model
def resnest50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest50d(pretrained=False, **kwargs):
    """ ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955
    Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample.
    """
    default_cfg = default_cfgs['resnest50d']
    model = ResNet(
        ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[3, 4, 6, 3],
        stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs)


@register_model
def resnest101e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest101e(pretrained=False, **kwargs):
    """ ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955
    Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
    """
    default_cfg = default_cfgs['resnest101e']
    model = ResNet(
        ResNestBottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans,
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[3, 4, 23, 3],
        stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs)


@register_model
def resnest200e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest200e(pretrained=False, **kwargs):
    """ ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955
    Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
    """
    default_cfg = default_cfgs['resnest200e']
    model = ResNet(
        ResNestBottleneck, [3, 24, 36, 3], num_classes=num_classes, in_chans=in_chans,
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[3, 24, 36, 3],
        stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs)


@register_model
def resnest269e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest269e(pretrained=False, **kwargs):
    """ ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955
    Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
    """
    default_cfg = default_cfgs['resnest269e']
    model = ResNet(
        ResNestBottleneck, [3, 30, 48, 8], num_classes=num_classes, in_chans=in_chans,
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[3, 30, 48, 8],
        stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs)


@register_model
def resnest50d_4s2x40d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest50d_4s2x40d(pretrained=False, **kwargs):
    """ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
    """
    default_cfg = default_cfgs['resnest50d_4s2x40d']
    model = ResNet(
        ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[3, 4, 6, 3],
        stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2,
        block_args=dict(radix=4, avd=True, avd_first=True), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs)


@register_model
def resnest50d_1s4x24d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnest50d_1s4x24d(pretrained=False, **kwargs):
    """ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
    """
    default_cfg = default_cfgs['resnest50d_1s4x24d']
    model = ResNet(
        ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[3, 4, 6, 3],
        stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
        block_args=dict(radix=1, avd=True, avd_first=True), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs)

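Note: the `_create_resnest` helper called above is not part of this hunk. Purely as an illustration of how the new kwargs-dict pattern is presumably consumed, a minimal sketch follows, modeled on the `_create_resnet_with_cfg` helper added to resnet.py later in this commit; the exact body in resnest.py may differ.

    # hypothetical sketch only -- the real helper is defined elsewhere in resnest.py
    def _create_resnest(variant, pretrained=False, **model_kwargs):
        default_cfg = default_cfgs[variant]  # per-variant config assumed registered in this file
        return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **model_kwargs)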
@ -6,11 +6,14 @@ additional dropout and dynamic global avg/max pool.
ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
"""
import math
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .features import FeatureNet
from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
from .registry import register_model
@ -390,6 +393,7 @@ class ResNet(nn.Module):
        self.base_width = base_width
        self.drop_rate = drop_rate
        self.expansion = block.expansion
        self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')]
        super(ResNet, self).__init__()

        # Stem
@ -420,9 +424,6 @@ class ResNet(nn.Module):
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Feature Blocks
        dp = DropPath(drop_path_rate) if drop_path_rate else None
        db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None
        db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None
        channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
        if output_stride == 16:
            strides[3] = 1
@ -432,14 +433,23 @@ class ResNet(nn.Module):
            dilations[2:4] = [2, 4]
        else:
            assert output_stride == 32
        dp = DropPath(drop_path_rate) if drop_path_rate else None
        db = [
            None, None,
            DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
            DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
        layer_args = list(zip(channels, layers, strides, dilations))
        layer_kwargs = dict(
            reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
            avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
        self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs)
        self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs)
        self.layer3 = self._make_layer(block, drop_block=db_3, *layer_args[2], **layer_kwargs)
        self.layer4 = self._make_layer(block, drop_block=db_4, *layer_args[3], **layer_kwargs)
        current_stride = 4
        for i in range(4):
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, self._make_layer(
                block, *layer_args[i], drop_block=db[i], **layer_kwargs))
            current_stride *= strides[i]
            self.feature_info.append(dict(
                num_chs=self.inplanes, reduction=current_stride, module=layer_name))

        # Head (Pooling and Classifier)
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
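For orientation, after this loop a default ResNet-50 built at output_stride 32 would be expected to carry feature metadata along the following lines (channel counts follow block.expansion; this listing is an illustration inferred from the code above, not output captured from this change):

    # expected model.feature_info for resnet50 (illustrative)
    # [{'num_chs': 64,   'reduction': 2,  'module': 'act1'},
    #  {'num_chs': 256,  'reduction': 4,  'module': 'layer1'},
    #  {'num_chs': 512,  'reduction': 8,  'module': 'layer2'},
    #  {'num_chs': 1024, 'reduction': 16, 'module': 'layer3'},
    #  {'num_chs': 2048, 'reduction': 32, 'module': 'layer4'}]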
@ -509,245 +519,185 @@ class ResNet(nn.Module):
        return x


def _create_resnet_with_cfg(variant, default_cfg, pretrained=False, **kwargs):
    assert isinstance(default_cfg, dict)
    load_strict, features = True, False
    out_indices = None
    if kwargs.pop('features_only', False):
        load_strict, features = False, True
        kwargs.pop('num_classes', 0)
        out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
    model = ResNet(**kwargs)
    model.default_cfg = copy.deepcopy(default_cfg)
    if kwargs.pop('pruned', False):
        model = adapt_model_from_file(model, variant)
    if pretrained:
        load_pretrained(
            model,
            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
    if features:
        model = FeatureNet(model, out_indices=out_indices)
    return model


def _create_resnet(variant, pretrained=False, **kwargs):
    default_cfg = default_cfgs[variant]
    return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)

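A minimal usage sketch of the new factory path, assuming the registered entrypoints below stay wired to `_create_resnet` as shown (the `features_only`/`out_indices` handling comes from the helper above; the shape of the returned features depends on FeatureNet and is only an expectation here):

    import torch
    # features_only routes through _create_resnet and wraps the model in FeatureNet
    backbone = resnet50(features_only=True, out_indices=(1, 2, 3, 4))
    feats = backbone(torch.randn(1, 3, 224, 224))
    # presumably one feature map per selected feature_info entry, at strides 4/8/16/32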
@register_model
def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    """
    default_cfg = default_cfgs['resnet18']
    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
    return _create_resnet('resnet18', pretrained, **model_args)


@register_model
def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    """
    default_cfg = default_cfgs['resnet34']
    model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('resnet34', pretrained, **model_args)


@register_model
def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet26(pretrained=False, **kwargs):
    """Constructs a ResNet-26 model.
    """
    default_cfg = default_cfgs['resnet26']
    model = ResNet(Bottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs)
    return _create_resnet('resnet26', pretrained, **model_args)


@register_model
def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet26d(pretrained=False, **kwargs):
    """Constructs a ResNet-26 v1d model.
    This is technically a 28 layer ResNet, sticking with 'd' modifier from Gluon for now.
    """
    default_cfg = default_cfgs['resnet26d']
    model = ResNet(
        Bottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_type='deep', avg_down=True, **kwargs)
    return _create_resnet('resnet26d', pretrained, **model_args)


@register_model
def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    """
    default_cfg = default_cfgs['resnet50']
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('resnet50', pretrained, **model_args)


@register_model
def resnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet50d(pretrained=False, **kwargs):
    """Constructs a ResNet-50-D model.
    """
    default_cfg = default_cfgs['resnet50d']
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
    return _create_resnet('resnet50d', pretrained, **model_args)


@register_model
def resnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    """
    default_cfg = default_cfgs['resnet101']
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
    return _create_resnet('resnet101', pretrained, **model_args)


@register_model
def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    """
    default_cfg = default_cfgs['resnet152']
    model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs)
    return _create_resnet('resnet152', pretrained, **model_args)


@register_model
def tv_resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def tv_resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model with original Torchvision weights.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfgs['tv_resnet34']
    if pretrained:
        load_pretrained(model, model.default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('tv_resnet34', pretrained, **model_args)


@register_model
def tv_resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def tv_resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model with original Torchvision weights.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfgs['tv_resnet50']
    if pretrained:
        load_pretrained(model, model.default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('tv_resnet50', pretrained, **model_args)


@register_model
def wide_resnet50_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def wide_resnet50_2(pretrained=False, **kwargs):
    """Constructs a Wide ResNet-50-2 model.
    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
    """
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], base_width=128,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfgs['wide_resnet50_2']
    if pretrained:
        load_pretrained(model, model.default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs)
    return _create_resnet('wide_resnet50_2', pretrained, **model_args)


@register_model
def wide_resnet101_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def wide_resnet101_2(pretrained=False, **kwargs):
    """Constructs a Wide ResNet-101-2 model.
    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same.
    """
    model = ResNet(
        Bottleneck, [3, 4, 23, 3], base_width=128,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfgs['wide_resnet101_2']
    if pretrained:
        load_pretrained(model, model.default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128, **kwargs)
    return _create_resnet('wide_resnet101_2', pretrained, **model_args)


@register_model
def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnext50_32x4d(pretrained=False, **kwargs):
    """Constructs a ResNeXt50-32x4d model.
    """
    default_cfg = default_cfgs['resnext50_32x4d']
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
    return _create_resnet('resnext50_32x4d', pretrained, **model_args)


@register_model
def resnext50d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnext50d_32x4d(pretrained=False, **kwargs):
    """Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample
    """
    default_cfg = default_cfgs['resnext50d_32x4d']
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
        stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
        stem_width=32, stem_type='deep', avg_down=True, **kwargs)
    return _create_resnet('resnext50d_32x4d', pretrained, **model_args)


@register_model
def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnext101_32x4d(pretrained=False, **kwargs):
    """Constructs a ResNeXt-101 32x4d model.
    """
    default_cfg = default_cfgs['resnext101_32x4d']
    model = ResNet(
        Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
    return _create_resnet('resnext101_32x4d', pretrained, **model_args)


@register_model
def resnext101_32x8d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnext101_32x8d(pretrained=False, **kwargs):
    """Constructs a ResNeXt-101 32x8d model.
    """
    default_cfg = default_cfgs['resnext101_32x8d']
    model = ResNet(
        Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
    return _create_resnet('resnext101_32x8d', pretrained, **model_args)


@register_model
def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnext101_64x4d(pretrained=False, **kwargs):
    """Constructs a ResNeXt101-64x4d model.
    """
    default_cfg = default_cfgs['resnext101_32x4d']
    model = ResNet(
        Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs)
    return _create_resnet('resnext101_64x4d', pretrained, **model_args)


@register_model
def tv_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def tv_resnext50_32x4d(pretrained=False, **kwargs):
    """Constructs a ResNeXt50-32x4d model with original Torchvision weights.
    """
    default_cfg = default_cfgs['tv_resnext50_32x4d']
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
    return _create_resnet('tv_resnext50_32x4d', pretrained, **model_args)


@register_model
@ -757,11 +707,8 @@ def ig_resnext101_32x8d(pretrained=True, **kwargs):
    `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
    Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
    model.default_cfg = default_cfgs['ig_resnext101_32x8d']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
    return _create_resnet('ig_resnext101_32x8d', pretrained, **model_args)


@register_model
@ -771,11 +718,8 @@ def ig_resnext101_32x16d(pretrained=True, **kwargs):
    `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
    Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
    model.default_cfg = default_cfgs['ig_resnext101_32x16d']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
    return _create_resnet('ig_resnext101_32x16d', pretrained, **model_args)


@register_model
@ -785,11 +729,8 @@ def ig_resnext101_32x32d(pretrained=True, **kwargs):
    `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
    Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=32, **kwargs)
    model.default_cfg = default_cfgs['ig_resnext101_32x32d']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32, **kwargs)
    return _create_resnet('ig_resnext101_32x32d', pretrained, **model_args)


@register_model
@ -799,11 +740,8 @@ def ig_resnext101_32x48d(pretrained=True, **kwargs):
    `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
    Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=48, **kwargs)
    model.default_cfg = default_cfgs['ig_resnext101_32x48d']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48, **kwargs)
    return _create_resnet('ig_resnext101_32x48d', pretrained, **model_args)


@register_model
@ -812,11 +750,8 @@ def ssl_resnet18(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    model.default_cfg = default_cfgs['ssl_resnet18']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
    return _create_resnet('ssl_resnet18', pretrained, **model_args)


@register_model
@ -825,11 +760,8 @@ def ssl_resnet50(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    model.default_cfg = default_cfgs['ssl_resnet50']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('ssl_resnet50', pretrained, **model_args)


@register_model
@ -838,11 +770,8 @@ def ssl_resnext50_32x4d(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
    model.default_cfg = default_cfgs['ssl_resnext50_32x4d']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
    return _create_resnet('ssl_resnext50_32x4d', pretrained, **model_args)


@register_model
@ -851,11 +780,8 @@ def ssl_resnext101_32x4d(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
    model.default_cfg = default_cfgs['ssl_resnext101_32x4d']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
    return _create_resnet('ssl_resnext101_32x4d', pretrained, **model_args)


@register_model
@ -864,11 +790,8 @@ def ssl_resnext101_32x8d(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
    model.default_cfg = default_cfgs['ssl_resnext101_32x8d']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
    return _create_resnet('ssl_resnext101_32x8d', pretrained, **model_args)


@register_model
@ -877,11 +800,8 @@ def ssl_resnext101_32x16d(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
    model.default_cfg = default_cfgs['ssl_resnext101_32x16d']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
    return _create_resnet('ssl_resnext101_32x16d', pretrained, **model_args)


@register_model
@ -891,11 +811,8 @@ def swsl_resnet18(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    model.default_cfg = default_cfgs['swsl_resnet18']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
    return _create_resnet('swsl_resnet18', pretrained, **model_args)


@register_model
@ -905,11 +822,8 @@ def swsl_resnet50(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    model.default_cfg = default_cfgs['swsl_resnet50']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('swsl_resnet50', pretrained, **model_args)


@register_model
@ -919,11 +833,8 @@ def swsl_resnext50_32x4d(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
    model.default_cfg = default_cfgs['swsl_resnext50_32x4d']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
    return _create_resnet('swsl_resnext50_32x4d', pretrained, **model_args)


@register_model
@ -933,11 +844,8 @@ def swsl_resnext101_32x4d(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
    model.default_cfg = default_cfgs['swsl_resnext101_32x4d']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
    return _create_resnet('swsl_resnext101_32x4d', pretrained, **model_args)


@register_model
@ -947,11 +855,8 @@ def swsl_resnext101_32x8d(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
    model.default_cfg = default_cfgs['swsl_resnext101_32x8d']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
    return _create_resnet('swsl_resnext101_32x8d', pretrained, **model_args)


@register_model
@ -961,61 +866,44 @@ def swsl_resnext101_32x16d(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
    model.default_cfg = default_cfgs['swsl_resnext101_32x16d']
    if pretrained:
        load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
    return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args)


@register_model
def seresnext26d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def seresnext26d_32x4d(pretrained=False, **kwargs):
    """Constructs a SE-ResNeXt-26-D model.
    This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
    combination of deep stem and avg_pool in downsample.
    """
    default_cfg = default_cfgs['seresnext26d_32x4d']
    model = ResNet(
        Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(
        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
        stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('seresnext26d_32x4d', pretrained, **model_args)


@register_model
def seresnext26t_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def seresnext26t_32x4d(pretrained=False, **kwargs):
    """Constructs a SE-ResNet-26-T model.
    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
    in the deep stem.
    """
    default_cfg = default_cfgs['seresnext26t_32x4d']
    model = ResNet(
        Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
        stem_width=32, stem_type='deep_tiered', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(
        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('seresnext26t_32x4d', pretrained, **model_args)


@register_model
def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def seresnext26tn_32x4d(pretrained=False, **kwargs):
    """Constructs a SE-ResNeXt-26-TN model.
    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
    in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
    """
    default_cfg = default_cfgs['seresnext26tn_32x4d']
    model = ResNet(
        Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
        stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(
        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
        stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('seresnext26tn_32x4d', pretrained, **model_args)


@register_model
@ -1025,145 +913,91 @@ def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
    this model replaces SE module with the ECA module
    """
    default_cfg = default_cfgs['ecaresnext26tn_32x4d']
    block_args = dict(attn_layer='eca')
    model = ResNet(
        Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
        stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(
        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
        stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
    return _create_resnet('ecaresnext26tn_32x4d', pretrained, **model_args)


@register_model
def ecaresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnet18(pretrained=False, **kwargs):
    """ Constructs an ECA-ResNet-18 model.
    """
    default_cfg = default_cfgs['ecaresnet18']
    block_args = dict(attn_layer='eca')
    model = ResNet(
        BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='eca'), **kwargs)
    return _create_resnet('ecaresnet18', pretrained, **model_args)


@register_model
def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnet50(pretrained=False, **kwargs):
    """Constructs an ECA-ResNet-50 model.
    """
    default_cfg = default_cfgs['ecaresnet50']
    block_args = dict(attn_layer='eca')
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='eca'), **kwargs)
    return _create_resnet('ecaresnet50', pretrained, **model_args)


@register_model
def ecaresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnet50d(pretrained=False, **kwargs):
    """Constructs a ResNet-50-D model with eca.
    """
    default_cfg = default_cfgs['ecaresnet50d']
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
        block_args=dict(attn_layer='eca'), **kwargs)
    return _create_resnet('ecaresnet50d', pretrained, **model_args)


@register_model
def ecaresnet50d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnet50d_pruned(pretrained=False, **kwargs):
    """Constructs a ResNet-50-D model pruned with eca.
    The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
    """
    variant = 'ecaresnet50d_pruned'
    default_cfg = default_cfgs[variant]
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
    model.default_cfg = default_cfg
    model = adapt_model_from_file(model, variant)

    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
        block_args=dict(attn_layer='eca'), **kwargs)
    return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args)

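The pruned variants above now lean on the shared helper: passing `pruned=True` makes `_create_resnet_with_cfg` call `adapt_model_from_file` after construction, so the adaption step that used to be inlined per function is handled centrally. A short hedged example:

    # builds the dense ECA-ResNet-50-D, then adapts its channel counts from the variant's pruning file
    model = ecaresnet50d_pruned(pretrained=False)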
@register_model
def ecaresnetlight(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnetlight(pretrained=False, **kwargs):
    """Constructs a ResNet-50-D light model with eca.
    """
    default_cfg = default_cfgs['ecaresnetlight']
    model = ResNet(
        Bottleneck, [1, 1, 11, 3], stem_width=32, avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(
        block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True,
        block_args=dict(attn_layer='eca'), **kwargs)
    return _create_resnet('ecaresnetlight', pretrained, **model_args)


@register_model
def ecaresnet101d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnet101d(pretrained=False, **kwargs):
    """Constructs a ResNet-101-D model with eca.
    """
    default_cfg = default_cfgs['ecaresnet101d']
    model = ResNet(
        Bottleneck, [3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
        block_args=dict(attn_layer='eca'), **kwargs)
    return _create_resnet('ecaresnet101d', pretrained, **model_args)


@register_model
def ecaresnet101d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def ecaresnet101d_pruned(pretrained=False, **kwargs):
    """Constructs a ResNet-101-D model pruned with eca.
    The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
    """
    variant = 'ecaresnet101d_pruned'
    default_cfg = default_cfgs[variant]
    model = ResNet(
        Bottleneck, [3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
        num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
    model.default_cfg = default_cfg
    model = adapt_model_from_file(model, variant)

    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
        block_args=dict(attn_layer='eca'), **kwargs)
    return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)


@register_model
def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnetblur18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model with blur anti-aliasing
    """
    default_cfg = default_cfgs['resnetblur18']
    model = ResNet(
        BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs)
    return _create_resnet('resnetblur18', pretrained, **model_args)


@register_model
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def resnetblur50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model with blur anti-aliasing
    """
    default_cfg = default_cfgs['resnetblur50']
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs)
    return _create_resnet('resnetblur50', pretrained, **model_args)

@ -16,6 +16,7 @@ import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .features import FeatureNet
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from .registry import register_model
@ -100,7 +101,8 @@ class SelecSLSBlock(nn.Module):
        self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1)

    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
        assert isinstance(x, list)
        if not isinstance(x, list):
            x = [x]
        assert len(x) in [1, 2]

        d1 = self.conv1(x[0])
@ -163,7 +165,7 @@ class SelecSLS(nn.Module):

    def forward_features(self, x):
        x = self.stem(x)
        x = self.features([x])
        x = self.features(x)
        x = self.head(x[0])
        return x

@ -178,6 +180,7 @@ class SelecSLS(nn.Module):

def _create_model(variant, pretrained, model_kwargs):
    cfg = {}
    feature_info = [dict(num_chs=32, reduction=2, module='stem.2')]
    if variant.startswith('selecsls42'):
        cfg['block'] = SelecSLSBlock
        # Define configuration of the network after the initial neck
@ -190,7 +193,13 @@ def _create_model(variant, pretrained, model_kwargs):
            (288, 0, 304, 304, True, 2),
            (304, 304, 304, 480, False, 1),
        ]
        feature_info.extend([
            dict(num_chs=128, reduction=4, module='features.1'),
            dict(num_chs=288, reduction=8, module='features.3'),
            dict(num_chs=480, reduction=16, module='features.5'),
        ])
        # Head can be replaced with alternative configurations depending on the problem
        feature_info.append(dict(num_chs=1024, reduction=32, module='head.1'))
        if variant == 'selecsls42b':
            cfg['head'] = [
                (480, 960, 3, 2),
@ -198,6 +207,7 @@ def _create_model(variant, pretrained, model_kwargs):
                (1024, 1280, 3, 2),
                (1280, 1024, 1, 1),
            ]
            feature_info.append(dict(num_chs=1024, reduction=64, module='head.3'))
            cfg['num_features'] = 1024
        else:
            cfg['head'] = [
@ -206,7 +216,9 @@ def _create_model(variant, pretrained, model_kwargs):
                (1024, 1024, 3, 2),
                (1024, 1280, 1, 1),
            ]
            feature_info.append(dict(num_chs=1280, reduction=64, module='head.3'))
            cfg['num_features'] = 1280

    elif variant.startswith('selecsls60'):
        cfg['block'] = SelecSLSBlock
        # Define configuration of the network after the initial neck
@ -222,7 +234,13 @@ def _create_model(variant, pretrained, model_kwargs):
            (288, 288, 288, 288, False, 1),
            (288, 288, 288, 416, False, 1),
        ]
        feature_info.extend([
            dict(num_chs=128, reduction=4, module='features.1'),
            dict(num_chs=288, reduction=8, module='features.4'),
            dict(num_chs=416, reduction=16, module='features.8'),
        ])
        # Head can be replaced with alternative configurations depending on the problem
        feature_info.append(dict(num_chs=1024, reduction=32, module='head.1'))
        if variant == 'selecsls60b':
            cfg['head'] = [
                (416, 756, 3, 2),
@ -230,6 +248,7 @@ def _create_model(variant, pretrained, model_kwargs):
                (1024, 1280, 3, 2),
                (1280, 1024, 1, 1),
            ]
            feature_info.append(dict(num_chs=1024, reduction=64, module='head.3'))
            cfg['num_features'] = 1024
        else:
            cfg['head'] = [
@ -238,7 +257,9 @@ def _create_model(variant, pretrained, model_kwargs):
                (1024, 1024, 3, 2),
                (1024, 1280, 1, 1),
            ]
            feature_info.append(dict(num_chs=1280, reduction=64, module='head.3'))
            cfg['num_features'] = 1280

    elif variant == 'selecsls84':
        cfg['block'] = SelecSLSBlock
        # Define configuration of the network after the initial neck
@ -258,6 +279,11 @@ def _create_model(variant, pretrained, model_kwargs):
            (304, 304, 304, 304, False, 1),
            (304, 304, 304, 512, False, 1),
        ]
        feature_info.extend([
            dict(num_chs=144, reduction=4, module='features.1'),
            dict(num_chs=304, reduction=8, module='features.6'),
            dict(num_chs=512, reduction=16, module='features.12'),
        ])
        # Head can be replaced with alternative configurations depending on the problem
        cfg['head'] = [
            (512, 960, 3, 2),
@ -266,17 +292,35 @@ def _create_model(variant, pretrained, model_kwargs):
            (1024, 1280, 3, 1),
        ]
        cfg['num_features'] = 1280
        feature_info.extend([
            dict(num_chs=1024, reduction=32, module='head.1'),
            dict(num_chs=1280, reduction=64, module='head.3')
        ])
    else:
        raise ValueError('Invalid net configuration ' + variant + ' !!!')

    load_strict = True
    features = False
    out_indices = None
    if model_kwargs.pop('features_only', False):
        load_strict = False
        features = True
        # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises?
        out_indices = model_kwargs.pop('out_indices', (0, 1, 2, 3, 4))
        model_kwargs.pop('num_classes', 0)

    model = SelecSLS(cfg, **model_kwargs)
    model.default_cfg = default_cfgs[variant]
    model.feature_info = feature_info
    if pretrained:
        load_pretrained(
            model,
            num_classes=model_kwargs.get('num_classes', 0),
            in_chans=model_kwargs.get('in_chans', 3),
            strict=True)
            strict=load_strict)

    if features:
        model = FeatureNet(model, out_indices, flatten_sequential=True)
    return model

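A short usage sketch for the SelecSLS path, under the same assumptions as for the ResNet helper (entrypoint and default_cfgs names as already registered in this file; the 0-4 default for out_indices is taken from the comment above):

    # hypothetical call; the registered selecsls60 entrypoint would forward here
    backbone = _create_model('selecsls60', False, dict(features_only=True))
    # presumably a FeatureNet wrapping SelecSLS, exposing 5 of the 6 available feature levels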
|
@ -12,11 +12,11 @@ import math
|
||||
|
||||
from torch import nn as nn
|
||||
|
||||
from .registry import register_model
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import load_pretrained
|
||||
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
|
||||
from .resnet import ResNet
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .registry import register_model
|
||||
from .resnet import _create_resnet_with_cfg
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
@ -138,101 +138,80 @@ class SelectiveKernelBottleneck(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def _create_skresnet(variant, pretrained=False, **kwargs):
|
||||
default_cfg = default_cfgs[variant]
|
||||
return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
def skresnet18(pretrained=False, **kwargs):
|
||||
"""Constructs a Selective Kernel ResNet-18 model.
|
||||
|
||||
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
|
||||
variation splits the input channels to the selective convolutions to keep param count down.
|
||||
"""
|
||||
default_cfg = default_cfgs['skresnet18']
|
||||
sk_kwargs = dict(
|
||||
min_attn_channels=16,
|
||||
attn_reduction=8,
|
||||
split_input=True
|
||||
)
|
||||
model = ResNet(
|
||||
SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
|
||||
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
split_input=True)
|
||||
model_args = dict(
|
||||
block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
|
||||
zero_init_last_bn=False, **kwargs)
|
||||
return _create_skresnet('skresnet18', pretrained, **model_args)

@register_model
def skresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def skresnet34(pretrained=False, **kwargs):
"""Constructs a Selective Kernel ResNet-34 model.

Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
default_cfg = default_cfgs['skresnet34']
sk_kwargs = dict(
min_attn_channels=16,
attn_reduction=8,
split_input=True
)
model = ResNet(
SelectiveKernelBasic, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
split_input=True)
model_args = dict(
block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
return _create_skresnet('skresnet34', pretrained, **model_args)


@register_model
def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def skresnet50(pretrained=False, **kwargs):
"""Constructs a Select Kernel ResNet-50 model.

Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
sk_kwargs = dict(
split_input=True,
)
default_cfg = default_cfgs['skresnet50']
model = ResNet(
SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
sk_kwargs = dict(split_input=True)
model_args = dict(
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
return _create_skresnet('skresnet50', pretrained, **model_args)


@register_model
def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def skresnet50d(pretrained=False, **kwargs):
"""Constructs a Select Kernel ResNet-50-D model.

Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
sk_kwargs = dict(
split_input=True,
)
default_cfg = default_cfgs['skresnet50d']
model = ResNet(
SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
sk_kwargs = dict(split_input=True)
model_args = dict(
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
return _create_skresnet('skresnet50d', pretrained, **model_args)


@register_model
def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def skresnext50_32x4d(pretrained=False, **kwargs):
"""Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
the SKNet-50 model in the Select Kernel Paper
"""
default_cfg = default_cfgs['skresnext50_32x4d']
model = ResNet(
SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
model_args = dict(
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
zero_init_last_bn=False, **kwargs)
return _create_skresnet('skresnext50_32x4d', pretrained, **model_args)
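
Once registered, any of the variants above can be built through timm's model factory; a minimal usage sketch (not part of the diff):

import torch
from timm import create_model

# Instantiate one of the Selective Kernel variants registered above and run a dummy batch.
model = create_model('skresnext50_32x4d', pretrained=False, num_classes=1000)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000])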

@ -20,6 +20,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
from .helpers import load_pretrained
from .features import FeatureNet
from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, SelectAdaptivePool2d, \
create_attn, create_norm_act, get_norm_act_layer

@ -296,6 +297,9 @@ class VovNet(nn.Module):
conv_type(stem_chs[0], stem_chs[1], 3, stride=1, norm_layer=norm_layer),
conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, norm_layer=norm_layer),
])
self.feature_info = [dict(
num_chs=stem_chs[1], reduction=2, module=f'stem.{1 if stem_stride == 4 else 2}')]
current_stride = stem_stride

# OSA stages
in_ch_list = stem_chs[-1:] + stage_out_chs[:-1]
@ -309,6 +313,9 @@ class VovNet(nn.Module):
downsample=downsample, **stage_args)
]
self.num_features = stage_out_chs[i]
current_stride *= 2 if downsample else 1
self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]

self.stages = nn.Sequential(*stages)

self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
@ -338,24 +345,24 @@ class VovNet(nn.Module):


def _vovnet(variant, pretrained=False, **kwargs):
load_strict = True
model_class = VovNet
features = False
out_indices = None
if kwargs.pop('features_only', False):
assert False, 'Not Implemented' # TODO
load_strict = False
features = True
kwargs.pop('num_classes', 0)
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
model_cfg = model_cfgs[variant]
default_cfg = default_cfgs[variant]
model = model_class(model_cfg, **kwargs)
model.default_cfg = default_cfg
model = VovNet(model_cfg, **kwargs)
model.default_cfg = default_cfgs[variant]
if pretrained:
load_pretrained(
model, default_cfg,
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
model,
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
if features:
model = FeatureNet(model, out_indices, flatten_sequential=True)
return model


@register_model
def vovnet39a(pretrained=False, **kwargs):
return _vovnet('vovnet39a', pretrained=pretrained, **kwargs)
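
With the features_only path in _vovnet now routing through FeatureNet and the feature_info entries populated above, a VovNet can serve as a multi-scale feature backbone. A minimal usage sketch, assuming the FeatureNet wrapper returns the selected intermediate feature maps as a list:

import torch
from timm import create_model

# Build vovnet39a as a feature extractor; out_indices selects entries from feature_info.
backbone = create_model('vovnet39a', features_only=True, out_indices=(1, 2, 3, 4))
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # one map per selected stage, spatial stride doubling each time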

@ -26,6 +26,7 @@ import torch.nn as nn
import torch.nn.functional as F

from .helpers import load_pretrained
from .features import FeatureNet
from .layers import SelectAdaptivePool2d
from .registry import register_model

@ -49,12 +50,12 @@ default_cfgs = {


class SeparableConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1):
super(SeparableConv2d, self).__init__()

self.conv1 = nn.Conv2d(
in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias)
self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=False)
self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False)

def forward(self, x):
x = self.conv1(x)
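
For context, SeparableConv2d factors a standard convolution into a depthwise step (conv1, grouped per channel) followed by a 1x1 pointwise projection, which is where Xception gets most of its parameter savings. A quick self-contained comparison, not part of the diff:

import torch.nn as nn

# 3x3 conv at 256 -> 256 channels, no bias:
#   standard:            256 * 256 * 3 * 3 = 589,824 params
#   depthwise+pointwise: 256 * 3 * 3 + 256 * 256 = 67,840 params
standard = nn.Conv2d(256, 256, 3, padding=1, bias=False)
depthwise = nn.Conv2d(256, 256, 3, padding=1, groups=256, bias=False)
pointwise = nn.Conv2d(256, 256, 1, bias=False)
n_standard = sum(p.numel() for p in standard.parameters())
n_separable = sum(p.numel() for p in depthwise.parameters()) + sum(p.numel() for p in pointwise.parameters())
print(n_standard, n_separable)  # 589824 67840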

@ -63,34 +64,26 @@ class Block(nn.Module):


class Block(nn.Module):
def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
def __init__(self, in_channels, out_channels, reps, strides=1, start_with_relu=True, grow_first=True):
super(Block, self).__init__()

if out_filters != in_filters or strides != 1:
self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
self.skipbn = nn.BatchNorm2d(out_filters)
if out_channels != in_channels or strides != 1:
self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False)
self.skipbn = nn.BatchNorm2d(out_channels)
else:
self.skip = None

self.relu = nn.ReLU(inplace=True)
rep = []

filters = in_filters
if grow_first:
rep.append(self.relu)
rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
rep.append(nn.BatchNorm2d(out_filters))
filters = out_filters

for i in range(reps - 1):
rep.append(self.relu)
rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False))
rep.append(nn.BatchNorm2d(filters))

if not grow_first:
rep.append(self.relu)
rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
rep.append(nn.BatchNorm2d(out_filters))
for i in range(reps):
if grow_first:
inc = in_channels if i == 0 else out_channels
outc = out_channels
else:
inc = in_channels
outc = in_channels if i < (reps - 1) else out_channels
rep.append(nn.ReLU(inplace=True))
rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1))
rep.append(nn.BatchNorm2d(outc))

if not start_with_relu:
rep = rep[1:]
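
The rewritten loop above replaces the old grow_first / reps-1 branches with a single pass; the channel progression it produces can be checked with a small standalone helper (hypothetical, mirroring the loop exactly):

def block_channel_plan(in_channels, out_channels, reps, grow_first=True):
    # Returns the (in, out) channel pair for each rep, matching Block.__init__ above.
    plan = []
    for i in range(reps):
        if grow_first:
            inc = in_channels if i == 0 else out_channels
            outc = out_channels
        else:
            inc = in_channels
            outc = in_channels if i < (reps - 1) else out_channels
        plan.append((inc, outc))
    return plan

print(block_channel_plan(64, 128, 2))                      # [(64, 128), (128, 128)]
print(block_channel_plan(728, 1024, 2, grow_first=False))  # [(728, 728), (728, 1024)]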

@ -133,34 +126,35 @@ class Xception(nn.Module):

self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False)
self.bn1 = nn.BatchNorm2d(32)
self.relu = nn.ReLU(inplace=True)
self.act1 = nn.ReLU(inplace=True)

self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
self.bn2 = nn.BatchNorm2d(64)
# do relu here
self.act2 = nn.ReLU(inplace=True)

self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)
self.block1 = Block(64, 128, 2, 2, start_with_relu=False)
self.block2 = Block(128, 256, 2, 2)
self.block3 = Block(256, 728, 2, 2)

self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block4 = Block(728, 728, 3, 1)
self.block5 = Block(728, 728, 3, 1)
self.block6 = Block(728, 728, 3, 1)
self.block7 = Block(728, 728, 3, 1)

self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
self.block8 = Block(728, 728, 3, 1)
self.block9 = Block(728, 728, 3, 1)
self.block10 = Block(728, 728, 3, 1)
self.block11 = Block(728, 728, 3, 1)

self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
self.block12 = Block(728, 1024, 2, 2, grow_first=False)

self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
self.bn3 = nn.BatchNorm2d(1536)
self.act3 = nn.ReLU(inplace=True)

# do relu here
self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1)
self.bn4 = nn.BatchNorm2d(self.num_features)
self.act4 = nn.ReLU(inplace=True)

self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -188,11 +182,11 @@ class Xception(nn.Module):
def forward_features(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.act1(x)

x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.act2(x)

x = self.block1(x)
x = self.block2(x)
@ -209,11 +203,11 @@ class Xception(nn.Module):

x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = self.act3(x)

x = self.conv4(x)
x = self.bn4(x)
x = self.relu(x)
x = self.act4(x)
return x

def forward(self, x):
@ -225,12 +219,28 @@ class Xception(nn.Module):
return x


@register_model
def xception(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
default_cfg = default_cfgs['xception']
model = Xception(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
def _xception(variant, pretrained=False, **kwargs):
load_strict = True
features = False
out_indices = None
if kwargs.pop('features_only', False):
load_strict = False
features = True
kwargs.pop('num_classes', 0)
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
model = Xception(**kwargs)
model.default_cfg = default_cfgs[variant]
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)

load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0),
in_chans=kwargs.get('in_chans', 3),
strict=load_strict)
if features:
model = FeatureNet(model, out_indices)
return model


@register_model
def xception(pretrained=False, **kwargs):
return _xception('xception', pretrained=pretrained, **kwargs)
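
The Xception builder follows the same features_only pattern as _vovnet, only without flatten_sequential, presumably because the blocks are plain module attributes rather than an nn.Sequential. A brief usage sketch under the same assumptions as the VovNet example above:

from timm import create_model

# Feature-extraction variant of Xception; uses the builder's default out_indices.
backbone = create_model('xception', features_only=True)
# backbone(x) should yield one feature map per selected stage rather than classification logits.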

60
validate.py
@ -24,9 +24,8 @@ try:
except ImportError:
has_apex = False

from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models,\
set_scriptable, set_no_jit
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging

torch.backends.cudnn.benchmark = True
@ -76,8 +75,25 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
help='Real labels JSON file for imagenet evaluation')


def set_jit_legacy():
""" Set JIT executor to legacy w/ support for op fusion
|
||||
This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
|
||||
in the JIT exectutor. These API are not supported so could change.
|
||||
"""
|
||||
#
assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._jit_override_can_fuse_on_gpu(True)
#torch._C._jit_set_texpr_fuser_enabled(True)


def validate(args):
@ -103,6 +119,8 @@ def validate(args):
model, test_time_pool = apply_test_time_pool(model, data_config, args)

if args.torchscript:
if args.legacy_jit:
set_jit_legacy()
torch.jit.optimized_execution(True)
model = torch.jit.script(model)

@ -116,13 +134,16 @@ def validate(args):

criterion = nn.CrossEntropyLoss().cuda()

#from torchvision.datasets import ImageNet
#dataset = ImageNet(args.data, split='val')
if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
else:
dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)

if args.real_labels:
real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
else:
real_labels = None

crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
loader = create_loader(
dataset,
@ -148,7 +169,7 @@ def validate(args):
input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
model(input)
end = time.time()
for i, (input, target) in enumerate(loader):
for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher:
target = target.cuda()
input = input.cuda()
@ -159,6 +180,9 @@ def validate(args):
output = model(input)
loss = criterion(output, target)

if real_labels is not None:
real_labels.add_result(output)

# measure accuracy and record loss
acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
@ -169,25 +193,35 @@ def validate(args):
batch_time.update(time.time() - end)
end = time.time()

if i % args.log_freq == 0:
if batch_idx % args.log_freq == 0:
logging.info(
'Test: [{0:>4d}/{1}] '
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
i, len(loader), batch_time=batch_time,
batch_idx, len(loader), batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))

results = OrderedDict(
top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
if real_labels is not None:
real_top1 = real_labels.get_accuracy(k=1)
real_top5 = real_labels.get_accuracy(k=5)
results = OrderedDict(
top1=round(real_top1, 4), top1_err=round(100 - real_top1, 4),
top5=round(real_top5, 4), top5_err=round(100 - real_top5, 4),
top1_original=round(top1.avg, 4),
top5_original=round(top5.avg, 4))
else:
results = OrderedDict(
top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4))
results.update(OrderedDict(
param_count=round(param_count / 1e6, 2),
img_size=data_config['input_size'][-1],
crop_pct=crop_pct,
interpolation=data_config['interpolation'])

interpolation=data_config['interpolation']
))
logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
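
Taken together, the --real-labels pieces above reduce to a small loop: construct the scorer from the dataset's file names, feed it each batch's logits, then read back both top-k accuracies. A condensed sketch based only on the calls shown above (dataset, loader and model as set up in validate()):

real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
for input, target in loader:
    output = model(input.cuda())
    real_labels.add_result(output)  # accumulates top-k hits against the reassessed labels
print(real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5))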